├── tests
│   ├── __init__.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── test_transform.py
│   │   └── test_anchors.py
│   ├── backend
│   │   ├── __init__.py
│   │   └── test_common.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── test_filter_detections.py
│   │   └── test_misc.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── test_densenet.py
│   │   └── test_mobilenet.py
│   ├── preprocessing
│   │   ├── __init__.py
│   │   ├── test_csv_generator.py
│   │   └── test_generator.py
│   ├── test_losses.py
│   └── bin
│       └── test_train.py
├── keras_retinanet
│   ├── __init__.py
│   ├── bin
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   └── __init__.cpython-36.pyc
│   │   ├── convert_model.py
│   │   └── evaluate.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── eval.cpython-36.pyc
│   │   │   ├── colors.cpython-36.pyc
│   │   │   ├── config.cpython-36.pyc
│   │   │   ├── image.cpython-36.pyc
│   │   │   ├── model.cpython-36.pyc
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── anchors.cpython-36.pyc
│   │   │   ├── transform.cpython-36.pyc
│   │   │   ├── keras_version.cpython-36.pyc
│   │   │   └── visualization.cpython-36.pyc
│   │   ├── compute_overlap.cpython-36m-x86_64-linux-gnu.so
│   │   ├── model.py
│   │   ├── config.py
│   │   ├── keras_version.py
│   │   ├── compute_overlap.pyx
│   │   ├── colors.py
│   │   ├── coco_eval.py
│   │   ├── visualization.py
│   │   ├── image.py
│   │   └── eval.py
│   ├── preprocessing
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── coco.cpython-36.pyc
│   │   │   ├── kitti.cpython-36.pyc
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── generator.cpython-36.pyc
│   │   │   ├── open_images.cpython-36.pyc
│   │   │   ├── pascal_voc.cpython-36.pyc
│   │   │   └── csv_generator.cpython-36.pyc
│   │   ├── coco.py
│   │   ├── kitti.py
│   │   ├── pascal_voc.py
│   │   └── csv_generator.py
│   ├── callbacks
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── eval.cpython-36.pyc
│   │   │   ├── common.cpython-36.pyc
│   │   │   └── __init__.cpython-36.pyc
│   │   ├── common.py
│   │   ├── coco.py
│   │   └── eval.py
│   ├── backend
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── common.cpython-36.pyc
│   │   │   ├── dynamic.cpython-36.pyc
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   └── tensorflow_backend.cpython-36.pyc
│   │   ├── cntk_backend.py
│   │   ├── theano_backend.py
│   │   ├── dynamic.py
│   │   ├── tensorflow_backend.py
│   │   └── common.py
│   ├── __pycache__
│   │   ├── losses.cpython-36.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   └── initializers.cpython-36.pyc
│   ├── layers
│   │   ├── __pycache__
│   │   │   ├── _misc.cpython-36.pyc
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   └── filter_detections.cpython-36.pyc
│   │   ├── __init__.py
│   │   └── _misc.py
│   ├── models
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── resnet.cpython-36.pyc
│   │   │   └── retinanet.cpython-36.pyc
│   │   ├── vgg.py
│   │   ├── densenet.py
│   │   ├── mobilenet.py
│   │   ├── __init__.py
│   │   └── resnet.py
│   ├── initializers.py
│   └── losses.py
├── data
│   ├── classes.csv
│   ├── convert.py
│   └── data_count_and_split.py
├── output
│   ├── esmble_jiao.py
│   ├── esmble.py
│   └── test_crop_result.py
├── README
├── test_crop_result.py
├── crop_test.py
├── crop_train.py
├── convert_model.py
├── crop_train_random.py
└── test.py

/tests/__init__.py:
--------------------------------------------------------------------------------
(empty)

/tests/utils/__init__.py:
--------------------------------------------------------------------------------
(empty)

/keras_retinanet/__init__.py:
--------------------------------------------------------------------------------
(empty)

/tests/backend/__init__.py:
--------------------------------------------------------------------------------
(empty)
/tests/layers/__init__.py:
--------------------------------------------------------------------------------
(empty)

/tests/models/__init__.py:
--------------------------------------------------------------------------------
(empty)

/keras_retinanet/bin/__init__.py:
--------------------------------------------------------------------------------
(empty)

/keras_retinanet/utils/__init__.py:
--------------------------------------------------------------------------------
(empty)

/tests/preprocessing/__init__.py:
--------------------------------------------------------------------------------
(empty)

/keras_retinanet/preprocessing/__init__.py:
--------------------------------------------------------------------------------
(empty)

/keras_retinanet/callbacks/__init__.py:
--------------------------------------------------------------------------------
from .common import *  # noqa: F401,F403

/keras_retinanet/backend/__init__.py:
--------------------------------------------------------------------------------
from .dynamic import *  # noqa: F401,F403
from .common import *  # noqa: F401,F403

[binary files omitted: the committed __pycache__/*.pyc bytecode files listed in
the tree above contain no readable source; each one is available at
https://raw.githubusercontent.com/moyan007/DF-traffic_sign_Detect/HEAD/<path>]
/keras_retinanet/layers/__init__.py:
--------------------------------------------------------------------------------
from ._misc import RegressBoxes, UpsampleLike, Anchors, ClipBoxes  # noqa: F401
from .filter_detections import FilterDetections  # noqa: F401
/keras_retinanet/utils/compute_overlap.cpython-36m-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
[compiled Cython extension; binary omitted, available at
https://raw.githubusercontent.com/moyan007/DF-traffic_sign_Detect/HEAD/keras_retinanet/utils/compute_overlap.cpython-36m-x86_64-linux-gnu.so]

/data/classes.csv:
--------------------------------------------------------------------------------
1,0
2,1
3,2
4,3
5,4
6,5
7,6
8,7
9,8
10,9
11,10
12,11
13,12
14,13
15,14
16,15
17,16
18,17
19,18
20,19

/keras_retinanet/backend/cntk_backend.py:
--------------------------------------------------------------------------------
"""
Copyright 2017-2018 Fizyr (https://fizyr.com)

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

/keras_retinanet/backend/theano_backend.py:
--------------------------------------------------------------------------------
"""
Copyright 2017-2018 Fizyr (https://fizyr.com)

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
15 | """ 16 | -------------------------------------------------------------------------------- /keras_retinanet/backend/dynamic.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | _BACKEND = "tensorflow" 4 | 5 | if "KERAS_BACKEND" in os.environ: 6 | _backend = os.environ["KERAS_BACKEND"] 7 | 8 | backends = { 9 | "cntk", 10 | "tensorflow", 11 | "theano" 12 | } 13 | 14 | assert _backend in backends 15 | 16 | _BACKEND = _backend 17 | 18 | if _BACKEND == "cntk": 19 | from .cntk_backend import * # noqa: F401,F403 20 | elif _BACKEND == "theano": 21 | from .theano_backend import * # noqa: F401,F403 22 | elif _BACKEND == "tensorflow": 23 | from .tensorflow_backend import * # noqa: F401,F403 24 | else: 25 | raise ValueError("Unknown backend: " + str(_BACKEND)) 26 | -------------------------------------------------------------------------------- /output/esmble_jiao.py: -------------------------------------------------------------------------------- 1 | #分别尝试求交集与并集来进行融合 2 | import pandas as pd 3 | from tqdm import tqdm 4 | 5 | result1 = pd.read_csv("test_upload_42.csv") 6 | result2 = pd.read_csv("test_upload_48.csv") 7 | 8 | # 1.交集 9 | for i in tqdm(range(result1.shape[0])): 10 | new_xmin = max(result1['X1'][i], result2['X1'][i]) 11 | new_xmax = min(result1['X2'][i], result2['X2'][i]) 12 | new_ymin = max(result1['Y1'][i], result2['Y1'][i]) 13 | new_ymax = min(result1['Y4'][i], result2['Y4'][i]) 14 | 15 | result1['X1'][i] = new_xmin 16 | result1['X2'][i] = new_xmax 17 | result1['X3'][i] = new_xmax 18 | result1['X4'][i] = new_xmin 19 | 20 | result1['Y1'][i] = new_ymin 21 | result1['Y2'][i] = new_ymin 22 | result1['Y3'][i] = new_ymax 23 | result1['Y4'][i] = new_ymax 24 | 25 | result1.to_csv('result_upload_jiao.csv', index=0) -------------------------------------------------------------------------------- /output/esmble.py: -------------------------------------------------------------------------------- 1 | #分别尝试求交集与并集来进行融合 2 | import pandas as pd 3 | from tqdm import tqdm 4 | 5 | result1 = pd.read_csv("test_upload_42.csv") 6 | result2 = pd.read_csv("test_upload_48.csv") 7 | # 并集 8 | for i in tqdm(range(result1.shape[0])): 9 | # print(i) 10 | new_xmin = min(result1['X1'][i], result2['X1'][i]) 11 | new_xmax = max(result1['X2'][i], result2['X2'][i]) 12 | new_ymin = min(result1['Y1'][i], result2['Y1'][i]) 13 | new_ymax = max(result1['Y4'][i], result2['Y4'][i]) 14 | 15 | result1['X1'][i] = new_xmin 16 | result1['X2'][i] = new_xmax 17 | result1['X3'][i] = new_xmax 18 | result1['X4'][i] = new_xmin 19 | 20 | result1['Y1'][i] = new_ymin 21 | result1['Y2'][i] = new_ymin 22 | result1['Y3'][i] = new_ymax 23 | result1['Y4'][i] = new_ymax 24 | 25 | result1.to_csv('result_upload_42_48.csv', index=0) 26 | -------------------------------------------------------------------------------- /README: -------------------------------------------------------------------------------- 1 | 基于RetinaNet检测器作为主干网络检测器来做的,预处理部分主要是对图像作了增强,从原图中扣取出400*400的图进行训练,更容易准确检测到目标。 2 | 测试集的话则根据之前比较好的测试结果来进行扣取测试图像。同时可以对多个检测模型的检测结果进行融合,取检测框的交集或者取并集具体根据自己的任务需求而定。 3 | 然后就是细致的调参过程了。 4 | 5 | 对于二阶检测算法不是太熟,尝试了mask RCNN,faster RCNN,可能是没有调好,线上只能调到0.8+,着实是一种遗憾,还是需要更加深入不断的学习。 6 | 7 | 最后其实关于数据增强方面有尝试随机裁剪原图像,效果应该可以更好,只是策略训练的太晚了,同时感觉对原图加一些图像增强及相关预处理效果可以更好。 8 | 9 | 总之,参加比赛嘛,一次学习和实践的过程,也着实在这个过程里学习到了不少的东西。分享给大家,交流学习一下 10 | 11 | 这里简单的补充下程序的使用方法: 12 | 首先运行crop_train.py文件,对原训练集进行裁剪,生成了data_train.csv标签文件; 13 | 然后运行data路径下的划分验证集和训练集的脚本,然后可以直接运行 14 | Python train.py csv data/train_label.csv data/classes.csv 
and it should start training normally.
After training finishes, run crop_test.py to crop the test set, then run
python convert_model.py model1.h5 model2.h5
to convert the model before it can be used for testing.
Finally run the test script (test.py) to generate the results, and map the
results on the crops back to the original images (via the test_crop_result.py
script) to obtain the final detections.
The ensemble scripts under output/ can fuse several detection results.
For tuning, dig further into the source code. If anything is unclear, feel
free to discuss.

/data/convert.py:
--------------------------------------------------------------------------------
"""This script extracts the training coordinates on the original images in
x1 y1 x2 y2 form.
train_label_0505.csv is the train_label_fix.csv file with the two wrongly
annotated images removed.
It generates train_only.csv and test_only.csv.
"""
import os
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# raw labels
raw_labels = pd.read_csv("train_label_0505.csv", header=None)[1:].values

with open("train_only.csv", "w") as f:
    for value in raw_labels:
        filename = os.getcwd() + "/Train_fix/" + value[0]
        label = str(value[1]) + ',' + str(value[2]) + ',' + str(value[5]) + ',' + str(value[6])
        type = value[9]  # -1

        f.write(filename + "," + label + ',' + type + "\n")

# raw labels
raw_labels = pd.read_csv("submit_sample_fix.csv", header=None)[1:].values

with open("test_only.csv", "w") as f:
    for value in raw_labels:
        filename = os.getcwd() + "/Test_fix/" + value[0]

        f.write(filename + '\n')

/tests/test_losses.py:
--------------------------------------------------------------------------------
import keras_retinanet.losses
import keras

import numpy as np

import pytest


def test_smooth_l1():
    regression = np.array([
        [
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
        ]
    ], dtype=keras.backend.floatx())
    regression = keras.backend.variable(regression)

    regression_target = np.array([
        [
            [0, 0, 0, 1, 1],
            [0, 0, 1, 0, 1],
            [0, 0, 0.05, 0, 1],
            [0, 0, 1, 0, 0],
        ]
    ], dtype=keras.backend.floatx())
    regression_target = keras.backend.variable(regression_target)

    loss = keras_retinanet.losses.smooth_l1()(regression_target, regression)
    loss = keras.backend.eval(loss)

    assert loss == pytest.approx((((1 - 0.5 / 9) * 2 + (0.5 * 9 * 0.05 ** 2)) / 3))

/test_crop_result.py:
--------------------------------------------------------------------------------
# This script maps the detections the network produced on the cropped images
# back onto the original images, so the result can be submitted.

import pandas as pd
import shutil
from tqdm import tqdm
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
import cv2

result = pd.read_csv("../result_crop.csv", index_col='filename')  # detection results

num = 0
with open("detla_test.txt", "r") as f:
    for line in f.readlines():
        pic_name = line.split(' ')[0]  # first field of "filename dx dy"
        detla_x = int(line.split(' ')[1])
        detla_y = int(line.split(' ')[2])  # int() also absorbs the trailing newline

        # shift through .loc so the DataFrame itself is updated (mutating the
        # row copy returned by result.loc[pic_name] would be silently lost)
        result.loc[pic_name, ['X1', 'X2', 'X3', 'X4']] += detla_x
        result.loc[pic_name, ['Y1', 'Y2', 'Y3', 'Y4']] += detla_y

        print("entry", num, ", class:", result.loc[pic_name, 'type'])
        num += 1

result.to_csv("final_submission.csv")
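[Hedged aside on test_crop_result.py above: the row-by-row loop is correct but
slow for ~20k detections. A minimal vectorized sketch — an assumption, not a
script from this repo, relying only on the space-separated "filename dx dy"
format of detla_test.txt used throughout — could look like this:]

# Vectorized variant of test_crop_result.py (sketch): join the per-image crop
# offsets onto the detections and shift all corner columns at once.
import pandas as pd

result = pd.read_csv("../result_crop.csv", index_col='filename')
offsets = pd.read_csv("detla_test.txt", sep=' ', header=None,
                      names=['filename', 'dx', 'dy'], index_col='filename')

merged = result.join(offsets)  # align offsets to detections by filename
# images skipped at crop time (type 0) have no offset entry; leave them unshifted
merged[['dx', 'dy']] = merged[['dx', 'dy']].fillna(0)
for col in ['X1', 'X2', 'X3', 'X4']:
    merged[col] += merged['dx']  # undo the horizontal crop offset
for col in ['Y1', 'Y2', 'Y3', 'Y4']:
    merged[col] += merged['dy']  # undo the vertical crop offset

merged.drop(columns=['dx', 'dy']).to_csv("final_submission.csv")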
/output/test_crop_result.py:
--------------------------------------------------------------------------------
# This script maps the detections the network produced on the cropped images
# back onto the original images, so the result can be submitted.

import pandas as pd
import shutil
from tqdm import tqdm
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
import cv2

result = pd.read_csv("./upload/test_upload_07.csv", index_col='filename')  # detection results

num = 0
with open("detla_test.txt", "r") as f:
    for line in f.readlines():
        pic_name = line.split(' ')[0]  # first field of "filename dx dy"
        detla_x = int(line.split(' ')[1])
        detla_y = int(line.split(' ')[2])  # int() also absorbs the trailing newline

        # shift through .loc so the DataFrame itself is updated
        result.loc[pic_name, ['X1', 'X2', 'X3', 'X4']] += detla_x
        result.loc[pic_name, ['Y1', 'Y2', 'Y3', 'Y4']] += detla_y

        print("entry", num, ", class:", result.loc[pic_name, 'type'])
        num += 1

result.to_csv("final_submission.csv")

/keras_retinanet/utils/model.py:
--------------------------------------------------------------------------------
"""
Copyright 2017-2018 Fizyr (https://fizyr.com)

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""


def freeze(model):
    """ Set all layers in a model to non-trainable.

    The weights for these layers will not be updated during training.

    This function modifies the given model in-place,
    but it also returns the modified model to allow easy chaining with other functions.
    """
    for layer in model.layers:
        layer.trainable = False
    return model

/keras_retinanet/initializers.py:
--------------------------------------------------------------------------------
"""
Copyright 2017-2018 Fizyr (https://fizyr.com)

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import keras

import numpy as np
import math


class PriorProbability(keras.initializers.Initializer):
    """ Apply a prior probability to the weights.
25 | """ 26 | 27 | def __init__(self, probability=0.01): 28 | self.probability = probability 29 | 30 | def get_config(self): 31 | return { 32 | 'probability': self.probability 33 | } 34 | 35 | def __call__(self, shape, dtype=None): 36 | # set bias to -log((1 - p)/p) for foreground 37 | result = np.ones(shape, dtype=dtype) * -math.log((1 - self.probability) / self.probability) 38 | 39 | return result 40 | -------------------------------------------------------------------------------- /keras_retinanet/utils/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import configparser 18 | import numpy as np 19 | import keras 20 | from ..utils.anchors import AnchorParameters 21 | 22 | 23 | def read_config_file(config_path): 24 | config = configparser.ConfigParser() 25 | config.read(config_path) 26 | 27 | return config 28 | 29 | 30 | def parse_anchor_parameters(config): 31 | ratios = np.array(list(map(float, config['anchor_parameters']['ratios'].split(' '))), keras.backend.floatx()) 32 | scales = np.array(list(map(float, config['anchor_parameters']['scales'].split(' '))), keras.backend.floatx()) 33 | sizes = list(map(int, config['anchor_parameters']['sizes'].split(' '))) 34 | strides = list(map(int, config['anchor_parameters']['strides'].split(' '))) 35 | 36 | return AnchorParameters(sizes, strides, ratios, scales) 37 | -------------------------------------------------------------------------------- /keras_retinanet/callbacks/common.py: -------------------------------------------------------------------------------- 1 | import keras.callbacks 2 | 3 | 4 | class RedirectModel(keras.callbacks.Callback): 5 | """Callback which wraps another callback, but executed on a different model. 6 | 7 | ```python 8 | model = keras.models.load_model('model.h5') 9 | model_checkpoint = ModelCheckpoint(filepath='snapshot.h5') 10 | parallel_model = multi_gpu_model(model, gpus=2) 11 | parallel_model.fit(X_train, Y_train, callbacks=[RedirectModel(model_checkpoint, model)]) 12 | ``` 13 | 14 | Args 15 | callback : callback to wrap. 16 | model : model to use when executing callbacks. 
17 | """ 18 | 19 | def __init__(self, 20 | callback, 21 | model): 22 | super(RedirectModel, self).__init__() 23 | 24 | self.callback = callback 25 | self.redirect_model = model 26 | 27 | def on_epoch_begin(self, epoch, logs=None): 28 | self.callback.on_epoch_begin(epoch, logs=logs) 29 | 30 | def on_epoch_end(self, epoch, logs=None): 31 | self.callback.on_epoch_end(epoch, logs=logs) 32 | 33 | def on_batch_begin(self, batch, logs=None): 34 | self.callback.on_batch_begin(batch, logs=logs) 35 | 36 | def on_batch_end(self, batch, logs=None): 37 | self.callback.on_batch_end(batch, logs=logs) 38 | 39 | def on_train_begin(self, logs=None): 40 | # overwrite the model with our custom model 41 | self.callback.set_model(self.redirect_model) 42 | 43 | self.callback.on_train_begin(logs=logs) 44 | 45 | def on_train_end(self, logs=None): 46 | self.callback.on_train_end(logs=logs) 47 | -------------------------------------------------------------------------------- /tests/models/test_densenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2018 vidosits (https://github.com/vidosits/) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import warnings 18 | import pytest 19 | import numpy as np 20 | import keras 21 | from keras_retinanet import losses 22 | from keras_retinanet.models.densenet import DenseNetBackbone 23 | 24 | parameters = ['densenet121'] 25 | 26 | 27 | @pytest.mark.parametrize("backbone", parameters) 28 | def test_backbone(backbone): 29 | # ignore warnings in this test 30 | warnings.simplefilter('ignore') 31 | 32 | num_classes = 10 33 | 34 | inputs = np.zeros((1, 200, 400, 3), dtype=np.float32) 35 | targets = [np.zeros((1, 14814, 5), dtype=np.float32), np.zeros((1, 14814, num_classes + 1))] 36 | 37 | inp = keras.layers.Input(inputs[0].shape) 38 | 39 | densenet_backbone = DenseNetBackbone(backbone) 40 | model = densenet_backbone.retinanet(num_classes=num_classes, inputs=inp) 41 | model.summary() 42 | 43 | # compile model 44 | model.compile( 45 | loss={ 46 | 'regression': losses.smooth_l1(), 47 | 'classification': losses.focal() 48 | }, 49 | optimizer=keras.optimizers.adam(lr=1e-5, clipnorm=0.001)) 50 | 51 | model.fit(inputs, targets, batch_size=1) 52 | -------------------------------------------------------------------------------- /keras_retinanet/utils/keras_version.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and
limitations under the License.
"""

from __future__ import print_function

import keras
import sys

minimum_keras_version = 2, 2, 4


def keras_version():
    """ Get the Keras version.

    Returns
        tuple of (major, minor, patch).
    """
    return tuple(map(int, keras.__version__.split('.')))


def keras_version_ok():
    """ Check if the current Keras version is higher than the minimum version.
    """
    return keras_version() >= minimum_keras_version


def assert_keras_version():
    """ Assert that the Keras version is up to date.
    """
    detected = keras.__version__
    required = '.'.join(map(str, minimum_keras_version))
    assert(keras_version() >= minimum_keras_version), 'You are using keras version {}. The minimum required version is {}.'.format(detected, required)


def check_keras_version():
    """ Check that the Keras version is up to date. If it isn't, print an error message and exit the script.
    """
    try:
        assert_keras_version()
    except AssertionError as e:
        print(e, file=sys.stderr)
        sys.exit(1)

/data/data_count_and_split.py:
--------------------------------------------------------------------------------
"""This script splits the cropped training set into the final training and
validation sets."""
import os
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import shutil
# raw labels
"""data_train.csv here contains the X1 Y1 X3 Y3 columns plus filename and type
extracted from the training file. The two files with broken labels must be
removed, and any other labels that look wrong can of course be removed as
well. Note that the file names and box coordinates here are the post-crop
names and boxes."""
raw_labels = pd.read_csv("data_train.csv", header=None).values
print(raw_labels.shape)  # (20238, 6)
filename = []
label = []
for value in raw_labels:
    filename.append(value[0])
    label.append(value[5])

# import collections
# print(collections.Counter(label))
# Counter({'17': 1022, '16': 1019, '18': 1017, '5': 1017, '11': 1014, '10': 1013, '14': 1013, '9': 1013, '20': 1013, '1': 1012, '13': 1011, '6': 1011, '19': 1011, '4': 1010, '15': 1009, '3': 1009, '7': 1008, '2': 1008, '8': 1007, '12': 1002})

X_train, X_val, y_train, y_val = train_test_split(filename, label, test_size=0.2, random_state=42, shuffle=True, stratify=label)  # 16191
# print(X_train)  # 'f6a2a4aea44a48a09d253538a310af71.jpg'

with open("train_label.csv", "w") as f, open("val_label.csv", "w") as g:
    for value in raw_labels:
        filename = os.getcwd() + "/train/" + value[0].split('/')[-1]
        label = str(value[1]) + ',' + str(value[2]) + ',' + str(value[3]) + ',' + str(value[4])
        type = str(value[5])

        if value[0] in X_train:
            dstname1 = os.getcwd() + "/train_pic/" + value[0].split('/')[-1]
            shutil.copy(filename, dstname1)
            f.write(dstname1 + "," + label + ',' + type + "\n")
        else:
            dstname2 = os.getcwd() + "/val_pic/" + value[0].split('/')[-1]
            shutil.copy(filename, dstname2)
            g.write(dstname2 + "," + label + ',' + type + "\n")

/keras_retinanet/utils/compute_overlap.pyx:
--------------------------------------------------------------------------------
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Sergey Karayev
# --------------------------------------------------------

cimport cython
import numpy as np
cimport numpy as np


def compute_overlap(
    np.ndarray[double, ndim=2] boxes,
    np.ndarray[double, ndim=2] query_boxes
):
    """
    Args
        boxes: (N, 4) ndarray of float
        query_boxes: (K, 4) ndarray of float

    Returns
        overlaps: (N, K) ndarray of overlap between boxes and query_boxes
    """
    cdef unsigned int N = boxes.shape[0]
    cdef unsigned int K = query_boxes.shape[0]
    cdef np.ndarray[double, ndim=2] overlaps = np.zeros((N, K), dtype=np.float64)
    cdef double iw, ih, box_area
    cdef double ua
    cdef unsigned int k, n
    for k in range(K):
        box_area = (
            (query_boxes[k, 2] - query_boxes[k, 0] + 1) *
            (query_boxes[k, 3] - query_boxes[k, 1] + 1)
        )
        for n in range(N):
            iw = (
                min(boxes[n, 2], query_boxes[k, 2]) -
                max(boxes[n, 0], query_boxes[k, 0]) + 1
            )
            if iw > 0:
                ih = (
                    min(boxes[n, 3], query_boxes[k, 3]) -
                    max(boxes[n, 1], query_boxes[k, 1]) + 1
                )
                if ih > 0:
                    ua = np.float64(
                        (boxes[n, 2] - boxes[n, 0] + 1) *
                        (boxes[n, 3] - boxes[n, 1] + 1) +
                        box_area - iw * ih
                    )
                    overlaps[n, k] = iw * ih / ua
    return overlaps

/crop_test.py:
--------------------------------------------------------------------------------
"""This script crops the test images.
Centered on the target detected earlier (by the previous online submission),
extend 200 pixels up/down/left/right and cut; images where nothing was
detected can only be re-annotated by hand, or dropped by marking them as 0."""

import pandas as pd
import shutil
from tqdm import tqdm
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
import cv2

df = pd.read_csv("./output/result_upload_42_48_007.csv")  # first-pass detections
t = open('detla_test.txt', 'w')
# print(df.shape)  # (20256, 10)
for i in tqdm(range(df.shape[0])):
    pic_name = df.filename[i]
    if (df.type[i] == 0):
        print("pass this pic")
        continue
    # re_pic_name = str('%06d' % i) + '.jpg'  # new zero-padded name

    img = Image.open('./data/Test_fix/' + pic_name)  # read the original image
    img_np = np.array(img)
    (height, width, deepth) = img_np.shape  # original image info

    x_center = int((df.X2[i] + df.X1[i]) / 2)  # centre of the annotated box
    y_center = int((df.Y1[i] + df.Y3[i]) / 2)

    crop_x_left = x_center - 200  # the four crop boundaries
    crop_x_right = x_center + 200
    crop_y_up = y_center - 200
    crop_y_down = y_center + 200

    if crop_x_left < 0:  # clip to the image
        crop_x_left = 0
    if crop_y_up < 0:
        crop_y_up = 0
    if crop_x_right > width:
        crop_x_right = width
    if crop_y_down > height:
        crop_y_down = height

    detla_x = crop_x_left  # record the x, y offsets for the final coordinate output
    detla_y = crop_y_up

    t.write(pic_name + ' ' + str(detla_x) + ' ' + str(detla_y) + '\n')  # record the offsets
    new_img = img.crop((crop_x_left, crop_y_up, crop_x_right, crop_y_down))
    # new_img.show()
    new_img.save('./data/test/' + pic_name)  # save the cropped image

t.close()

/tests/models/test_mobilenet.py:
--------------------------------------------------------------------------------
"""
Copyright 2017-2018 lvaleriu (https://github.com/lvaleriu/)

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import warnings
import pytest
import numpy as np
import keras
from keras_retinanet import losses
from keras_retinanet.models.mobilenet import MobileNetBackbone

alphas = ['1.0']
parameters = []

for backbone in MobileNetBackbone.allowed_backbones:
    for alpha in alphas:
        parameters.append((backbone, alpha))


@pytest.mark.parametrize("backbone, alpha", parameters)
def test_backbone(backbone, alpha):
    # ignore warnings in this test
    warnings.simplefilter('ignore')

    num_classes = 10

    inputs = np.zeros((1, 1024, 363, 3), dtype=np.float32)
    targets = [np.zeros((1, 68760, 5), dtype=np.float32), np.zeros((1, 68760, num_classes + 1))]

    inp = keras.layers.Input(inputs[0].shape)

    mobilenet_backbone = MobileNetBackbone(backbone='{}_{}'.format(backbone, alpha))
    training_model = mobilenet_backbone.retinanet(num_classes=num_classes, inputs=inp)
    training_model.summary()

    # compile model
    training_model.compile(
        loss={
            'regression': losses.smooth_l1(),
            'classification': losses.focal()
        },
        optimizer=keras.optimizers.adam(lr=1e-5, clipnorm=0.001))

    training_model.fit(inputs, targets, batch_size=1)

/crop_train.py:
--------------------------------------------------------------------------------
"""This script crops the training images.
Centered on the target, extend 200 pixels up/down/left/right and cut.
"""
import pandas as pd
import shutil
import os
from tqdm import tqdm
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
import cv2

"""train_only.csv here contains the X1 Y1 X3 Y3 columns plus filename and type
extracted from the original training file train_label_fix.csv, with the two
obviously mislabelled images cleaned out; if more bad labels are found, it is
even better to remove those as well."""
df = pd.read_csv("./data/train_only.csv", header=None)
t = open('./data/detla.txt', 'w')

file = open('./data/data_train.csv', "w")
# print(df.shape)  # (20239, 6)
for i in tqdm(range(df.shape[0])):
    pic_name = df[0][i]
    # print(pic_name)
    re_pic_name = str('%06d' % i) + '.jpg'  # new zero-padded name

    # img = Image.open(pic_name)  # read the original image
    # img_np = np.array(img)
    # (height, width, deepth) = img_np.shape  # original image info
    height, width = 1800, 3200
    x_center = int((df[1][i] + df[3][i]) / 2)  # centre of the annotated box
    y_center = int((df[2][i] + df[4][i]) / 2)

    crop_x_left = x_center - 200  # the four crop boundaries
    crop_x_right = x_center + 200
    crop_y_up = y_center - 200
    crop_y_down = y_center + 200

    if crop_x_left < 0:  # clip to the image
        crop_x_left = 0
    if crop_y_up < 0:
        crop_y_up = 0
    if crop_x_right > width:
        crop_x_right = width
    if crop_y_down > height:
        crop_y_down = height

    detla_x = crop_x_left  # record the x, y offsets for the final coordinate output
    detla_y = crop_y_up

    t.write(pic_name + ' ' + str(detla_x) + ' ' + str(detla_y) + '\n')  # record the offsets

    # new_img = img.crop((crop_x_left, crop_y_up, crop_x_right, crop_y_down))
    # # new_img.show()
    # new_img.save('./data/train/' +
    #              re_pic_name)  # save the cropped image

    xmin = df[1][i] - detla_x
    ymin = df[2][i] - detla_y
    xmax = df[3][i] - detla_x
    ymax = df[4][i] - detla_y
    file_name = os.getcwd() + '/data/train/' + re_pic_name
    # print(type(xmin))
    file.write(file_name + ',' + str(xmin) + ',' + str(ymin) + ',' + str(xmax) + ',' +
               str(ymax) + ',' + str(df[5][i]) + '\n')
    # img_show = cv2.imread('./data/train/' + re_pic_name)
    # cv2.rectangle(img_show, (xmin, ymin), (xmax, ymax), (0, 255, 0), 2)
    # plt.imshow(img_show)
    # plt.show()

t.close()

/tests/bin/test_train.py:
--------------------------------------------------------------------------------
"""
Copyright 2017-2018 Fizyr (https://fizyr.com)

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import keras_retinanet.bin.train
import keras.backend

import warnings

import pytest


@pytest.fixture(autouse=True)
def clear_session():
    # run before test (do nothing)
    yield
    # run after test, clear keras session
    keras.backend.clear_session()


def test_coco():
    # ignore warnings in this test
    warnings.simplefilter('ignore')

    # run training / evaluation
    keras_retinanet.bin.train.main([
        '--epochs=1',
        '--steps=1',
        '--no-weights',
        '--no-snapshots',
        'coco',
        'tests/test-data/coco',
    ])


def test_pascal():
    # ignore warnings in this test
    warnings.simplefilter('ignore')

    # run training / evaluation
    keras_retinanet.bin.train.main([
        '--epochs=1',
        '--steps=1',
        '--no-weights',
        '--no-snapshots',
        'pascal',
        'tests/test-data/pascal',
    ])


def test_csv():
    # ignore warnings in this test
    warnings.simplefilter('ignore')

    # run training / evaluation
    keras_retinanet.bin.train.main([
        '--epochs=1',
        '--steps=1',
        '--no-weights',
        '--no-snapshots',
        'csv',
        'tests/test-data/csv/annotations.csv',
        'tests/test-data/csv/classes.csv',
    ])


def test_vgg():
    # ignore warnings in this test
    warnings.simplefilter('ignore')

    # run training / evaluation
    keras_retinanet.bin.train.main([
        '--backbone=vgg16',
        '--epochs=1',
        '--steps=1',
        '--no-weights',
        '--no-snapshots',
        '--freeze-backbone',
        'coco',
        'tests/test-data/coco',
    ])

/convert_model.py:
--------------------------------------------------------------------------------
# This script converts a model produced by training into one that can be used
# for testing.
import argparse
import os
import sys

import keras
import tensorflow as tf

# Allow relative imports when being executed as script.
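# (Hedged clarifying note, added to the original: when this file is executed
# directly, __package__ is None, so the relative imports further down
# ("from .. import models") would fail; the branch below prepends a parent
# directory to sys.path and sets __package__ so Python resolves the script as
# part of the keras_retinanet.bin package.)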
if __name__ == "__main__" and __package__ is None:
    sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
    import keras_retinanet.bin  # noqa: F401
    __package__ = "keras_retinanet.bin"

# Change these to absolute imports if you copy this script outside the keras_retinanet package.
from .. import models
from ..utils.config import read_config_file, parse_anchor_parameters


def get_session():
    """ Construct a modified tf session.
    """
    config = tf.ConfigProto()
    os.environ["CUDA_VISIBLE_DEVICES"] = "3"
    return tf.Session(config=config)


def parse_args(args):
    parser = argparse.ArgumentParser(description='Script for converting a training model to an inference model.')

    parser.add_argument('model_in', help='The model to convert.', default='./snapshots/resnet50_pascal_40.h5')
    parser.add_argument('model_out', help='Path to save the converted model to.', default='./snapshots/model/resnet50_40.h5')
    parser.add_argument('--backbone', help='The backbone of the model to convert.', default='resnet50')
    parser.add_argument('--no-nms', help='Disables non maximum suppression.', dest='nms', action='store_false')
    parser.add_argument('--no-class-specific-filter', help='Disables class specific filtering.', dest='class_specific_filter', action='store_false')
    parser.add_argument('--config', help='Path to a configuration parameters .ini file.')

    return parser.parse_args(args)


def main(args=None):
    # parse arguments
    if args is None:
        args = sys.argv[1:]
    args = parse_args(args)

    # Set modified tf session to avoid using the GPUs
    keras.backend.tensorflow_backend.set_session(get_session())

    # optionally load config parameters
    anchor_parameters = None
    if args.config:
        args.config = read_config_file(args.config)
        if 'anchor_parameters' in args.config:
            anchor_parameters = parse_anchor_parameters(args.config)

    # load the model
    model = models.load_model(args.model_in, backbone_name=args.backbone)

    # check if this is indeed a training model
    models.check_training_model(model)

    # convert the model
    model = models.convert_model(model, nms=args.nms, class_specific_filter=args.class_specific_filter, anchor_params=anchor_parameters)

    # save model
    model.save(args.model_out)


if __name__ == '__main__':
    main()

/crop_train_random.py:
--------------------------------------------------------------------------------
# Random-crop script for the training images: as long as the target stays
# inside the cropped image the crop is valid, which augments the data and can
# multiply the amount of training data several times.
import pandas as pd
import shutil
import os
from tqdm import tqdm
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
import cv2
import random

df = pd.read_csv("./data/train_only.csv", header=None)
# t = open('./data/detla.txt', 'w')

file = open('./data/data_train.csv', "w")
print(df.shape)  # (20239, 6)

for j in range(5):  # multiply the amount of data by 5
    for i in tqdm(range(df.shape[0])):
        pic_name = df[0][i]
        # print(pic_name)
        re_pic_name = str(i) + '_' + str(j) + '.jpg'  # new name

        img = Image.open(pic_name)  # read the original image
        img_np = np.array(img)
        (height, width, deepth) = img_np.shape  # original image info
        # height, width = 1800, 3200
        x_center = int((df[1][i] + df[3][i]) / 2)  # centre of the annotated box
        y_center = int((df[2][i] + df[4][i]) / 2)
        x_center += random.randint(-80, 80)  # jitter the centre by a random offset
        y_center += random.randint(-80, 80)
        if x_center < 0:  # clip to the image
            x_center = 0
        if x_center > width:
            x_center = width
        if y_center < 0:
            y_center = 0
        if y_center > height:
            y_center = height

        crop_x_left = x_center - 200  # the four crop boundaries
        crop_x_right = x_center + 200
        crop_y_up = y_center - 200
        crop_y_down = y_center + 200

        if crop_x_left < 0:  # clip to the image
            crop_x_left = 0
        if crop_y_up < 0:
            crop_y_up = 0
        if crop_x_right > width:
            crop_x_right = width
        if crop_y_down > height:
            crop_y_down = height

        detla_x = crop_x_left  # record the x, y offsets for the final coordinate output
        detla_y = crop_y_up

        # t.write(pic_name + ' ' + str(detla_x) + ' ' + str(detla_y) + '\n')  # record the offsets

        new_img = img.crop((crop_x_left, crop_y_up, crop_x_right, crop_y_down))  # crop
        # new_img.show()
        new_img.save('./data/train/' + re_pic_name)  # save the cropped image

        xmin = df[1][i] - detla_x
        ymin = df[2][i] - detla_y
        xmax = df[3][i] - detla_x
        ymax = df[4][i] - detla_y
        file_name = os.getcwd() + '/data/train/' + re_pic_name
        # print(type(xmin))
        file.write(file_name + ',' + str(xmin) + ',' + str(ymin) + ',' + str(xmax) + ',' +
                   str(ymax) + ',' + str(df[5][i]) + '\n')
        # img_show = cv2.imread('./data/train/' + re_pic_name)
        # cv2.rectangle(img_show, (xmin, ymin), (xmax, ymax), (0, 255, 0), 2)
        # plt.imshow(img_show)
        # plt.show()

/keras_retinanet/callbacks/coco.py:
--------------------------------------------------------------------------------
"""
Copyright 2017-2018 Fizyr (https://fizyr.com)

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import keras
from ..utils.coco_eval import evaluate_coco


class CocoEval(keras.callbacks.Callback):
    """ Performs COCO evaluation on each epoch.
    """
    def __init__(self, generator, tensorboard=None, threshold=0.05):
        """ CocoEval callback initializer.

        Args
            generator   : The generator used for creating validation data.
            tensorboard : If given, the results will be written to tensorboard.
            threshold   : The score threshold to use.
31 | """ 32 | self.generator = generator 33 | self.threshold = threshold 34 | self.tensorboard = tensorboard 35 | 36 | super(CocoEval, self).__init__() 37 | 38 | def on_epoch_end(self, epoch, logs=None): 39 | logs = logs or {} 40 | 41 | coco_tag = ['AP @[ IoU=0.50:0.95 | area= all | maxDets=100 ]', 42 | 'AP @[ IoU=0.50 | area= all | maxDets=100 ]', 43 | 'AP @[ IoU=0.75 | area= all | maxDets=100 ]', 44 | 'AP @[ IoU=0.50:0.95 | area= small | maxDets=100 ]', 45 | 'AP @[ IoU=0.50:0.95 | area=medium | maxDets=100 ]', 46 | 'AP @[ IoU=0.50:0.95 | area= large | maxDets=100 ]', 47 | 'AR @[ IoU=0.50:0.95 | area= all | maxDets= 1 ]', 48 | 'AR @[ IoU=0.50:0.95 | area= all | maxDets= 10 ]', 49 | 'AR @[ IoU=0.50:0.95 | area= all | maxDets=100 ]', 50 | 'AR @[ IoU=0.50:0.95 | area= small | maxDets=100 ]', 51 | 'AR @[ IoU=0.50:0.95 | area=medium | maxDets=100 ]', 52 | 'AR @[ IoU=0.50:0.95 | area= large | maxDets=100 ]'] 53 | coco_eval_stats = evaluate_coco(self.generator, self.model, self.threshold) 54 | if coco_eval_stats is not None and self.tensorboard is not None and self.tensorboard.writer is not None: 55 | import tensorflow as tf 56 | summary = tf.Summary() 57 | for index, result in enumerate(coco_eval_stats): 58 | summary_value = summary.value.add() 59 | summary_value.simple_value = result 60 | summary_value.tag = '{}. {}'.format(index + 1, coco_tag[index]) 61 | self.tensorboard.writer.add_summary(summary, epoch) 62 | logs[coco_tag[index]] = result 63 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | """测试生成提交检测结果文件,之后要对先前的采集进行还原以及其他后处理 2 | 主要更改测试采用的模型及生成csv文件名即可 3 | """ 4 | import keras 5 | 6 | # import keras_retinanet 7 | from keras_retinanet import models 8 | from keras_retinanet.utils.image import read_image_bgr, preprocess_image, resize_image 9 | from keras_retinanet.utils.visualization import draw_box, draw_caption 10 | from keras_retinanet.utils.colors import label_color 11 | import pandas as pd 12 | # import miscellaneous modules 13 | import matplotlib.pyplot as plt 14 | import cv2 15 | import os 16 | import numpy as np 17 | import time 18 | 19 | # set tf backend to allow memory to grow, instead of claiming everything 20 | import tensorflow as tf 21 | 22 | def get_session(): 23 | config = tf.ConfigProto() 24 | config.gpu_options.allow_growth = True 25 | return tf.Session(config=config) 26 | 27 | # use this environment flag to change which GPU to use 28 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 29 | 30 | # set the modified tf session as backend in keras 31 | keras.backend.tensorflow_backend.set_session(get_session()) 32 | 33 | # adjust this to point to your downloaded/trained model 34 | 35 | model_path = os.path.join('.','models', 'resnet101_11.h5') #训练的模型 36 | # load retinanet model 37 | model = models.load_model(model_path, backbone_name='resnet101') 38 | 39 | # if the model is not converted to an inference model, use the line below 40 | #model = models.convert_model(model) 41 | # print(model.summary()) 42 | 43 | # load label to names mapping for visualization purposes 44 | labels_to_names = {0: '1',1: '2',2: '3',3: '4',4: '5',5: '6',6: '7',7: '8',8: '9',9: '10',10: '11',11: '12',12: '13',13: '14',14: '15',15: '16',16: '17',17: '18',18: '19',19: '20'} 45 | # load image 46 | filename = './output/result_upload_42_48_007.csv' 47 | data = pd.read_csv(filename,header=None,index_col=None)[1:].values 48 | # print(data.shape) 49 | with 
open("./output/upload/test_upload_11.csv","w") as f: 50 | f.write('filename' + "," + 'X1'+","+'Y1'+","+'X2'+","+'Y2'+","+'X3'+","+'Y3'+","+'X4'+","+'Y4'+","+"type" + "\n") 51 | i = 0 52 | for value in data: 53 | i += 1 54 | print(i) 55 | if(value[9]=='0'): 56 | print("pass this pic") 57 | continue 58 | image = read_image_bgr('./data/test/'+value[0]) 59 | # copy to draw on 60 | draw = image.copy() 61 | draw = cv2.cvtColor(draw, cv2.COLOR_BGR2RGB) 62 | 63 | # preprocess image for network 64 | image = preprocess_image(image) 65 | file = value[0] 66 | 67 | boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0)) 68 | 69 | box, score, label = boxes[0][0],scores[0][0],labels[0][0] 70 | 71 | b = box 72 | if(score == -1): 73 | b = [0,0,0,0] 74 | label = '0' 75 | else: 76 | label = labels_to_names[label] 77 | bbox = str(b[0])+","+str(b[1])+","+str(b[2])+","+str(b[1])+","+\ 78 | str(b[2])+","+str(b[3])+","+str(b[0])+","+str(b[3]) 79 | print(bbox,score,label) 80 | f.write(file+","+bbox +","+ label+"\n") 81 | 82 | 83 | print("write done") 84 | -------------------------------------------------------------------------------- /keras_retinanet/utils/colors.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | 4 | def label_color(label): 5 | """ Return a color from a set of predefined colors. Contains 80 colors in total. 6 | 7 | Args 8 | label: The label to get the color for. 9 | 10 | Returns 11 | A list of three values representing a RGB color. 12 | 13 | If no color is defined for a certain label, the color green is returned and a warning is printed. 14 | """ 15 | if label < len(colors): 16 | return colors[label] 17 | else: 18 | warnings.warn('Label {} has no color, returning default.'.format(label)) 19 | return (0, 255, 0) 20 | 21 | 22 | """ 23 | Generated using: 24 | 25 | ``` 26 | colors = [list((matplotlib.colors.hsv_to_rgb([x, 1.0, 1.0]) * 255).astype(int)) for x in np.arange(0, 1, 1.0 / 80)] 27 | shuffle(colors) 28 | pprint(colors) 29 | ``` 30 | """ 31 | colors = [ 32 | [31 , 0 , 255] , 33 | [0 , 159 , 255] , 34 | [255 , 95 , 0] , 35 | [255 , 19 , 0] , 36 | [255 , 0 , 0] , 37 | [255 , 38 , 0] , 38 | [0 , 255 , 25] , 39 | [255 , 0 , 133] , 40 | [255 , 172 , 0] , 41 | [108 , 0 , 255] , 42 | [0 , 82 , 255] , 43 | [0 , 255 , 6] , 44 | [255 , 0 , 152] , 45 | [223 , 0 , 255] , 46 | [12 , 0 , 255] , 47 | [0 , 255 , 178] , 48 | [108 , 255 , 0] , 49 | [184 , 0 , 255] , 50 | [255 , 0 , 76] , 51 | [146 , 255 , 0] , 52 | [51 , 0 , 255] , 53 | [0 , 197 , 255] , 54 | [255 , 248 , 0] , 55 | [255 , 0 , 19] , 56 | [255 , 0 , 38] , 57 | [89 , 255 , 0] , 58 | [127 , 255 , 0] , 59 | [255 , 153 , 0] , 60 | [0 , 255 , 255] , 61 | [0 , 255 , 216] , 62 | [0 , 255 , 121] , 63 | [255 , 0 , 248] , 64 | [70 , 0 , 255] , 65 | [0 , 255 , 159] , 66 | [0 , 216 , 255] , 67 | [0 , 6 , 255] , 68 | [0 , 63 , 255] , 69 | [31 , 255 , 0] , 70 | [255 , 57 , 0] , 71 | [255 , 0 , 210] , 72 | [0 , 255 , 102] , 73 | [242 , 255 , 0] , 74 | [255 , 191 , 0] , 75 | [0 , 255 , 63] , 76 | [255 , 0 , 95] , 77 | [146 , 0 , 255] , 78 | [184 , 255 , 0] , 79 | [255 , 114 , 0] , 80 | [0 , 255 , 235] , 81 | [255 , 229 , 0] , 82 | [0 , 178 , 255] , 83 | [255 , 0 , 114] , 84 | [255 , 0 , 57] , 85 | [0 , 140 , 255] , 86 | [0 , 121 , 255] , 87 | [12 , 255 , 0] , 88 | [255 , 210 , 0] , 89 | [0 , 255 , 44] , 90 | [165 , 255 , 0] , 91 | [0 , 25 , 255] , 92 | [0 , 255 , 140] , 93 | [0 , 101 , 255] , 94 | [0 , 255 , 82] , 95 | [223 , 255 , 0] , 96 | [242 , 0 , 255] , 97 | [89 , 0 , 255] 
, 98 | [165 , 0 , 255] , 99 | [70 , 255 , 0] , 100 | [255 , 0 , 172] , 101 | [255 , 76 , 0] , 102 | [203 , 255 , 0] , 103 | [204 , 0 , 255] , 104 | [255 , 0 , 229] , 105 | [255 , 133 , 0] , 106 | [127 , 0 , 255] , 107 | [0 , 235 , 255] , 108 | [0 , 255 , 197] , 109 | [255 , 0 , 191] , 110 | [0 , 44 , 255] , 111 | [50 , 255 , 0] 112 | ] 113 | -------------------------------------------------------------------------------- /keras_retinanet/bin/convert_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Copyright 2017-2018 Fizyr (https://fizyr.com) 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 17 | """ 18 | 19 | import argparse 20 | import os 21 | import sys 22 | 23 | import keras 24 | import tensorflow as tf 25 | 26 | # Allow relative imports when being executed as script. 27 | if __name__ == "__main__" and __package__ is None: 28 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) 29 | import keras_retinanet.bin # noqa: F401 30 | __package__ = "keras_retinanet.bin" 31 | 32 | # Change these to absolute imports if you copy this script outside the keras_retinanet package. 33 | from .. import models 34 | from ..utils.config import read_config_file, parse_anchor_parameters 35 | 36 | 37 | def get_session(): 38 | """ Construct a modified tf session. 
39 | """ 40 | config = tf.ConfigProto() 41 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 42 | return tf.Session(config=config) 43 | 44 | 45 | def parse_args(args): 46 | parser = argparse.ArgumentParser(description='Script for converting a training model to an inference model.') 47 | 48 | parser.add_argument('model_in', help='The model to convert.') 49 | parser.add_argument('model_out', help='Path to save the converted model to.') 50 | parser.add_argument('--backbone', help='The backbone of the model to convert.', default='resnet50') 51 | parser.add_argument('--no-nms', help='Disables non maximum suppression.', dest='nms', action='store_false') 52 | parser.add_argument('--no-class-specific-filter', help='Disables class specific filtering.', dest='class_specific_filter', action='store_false') 53 | parser.add_argument('--config', help='Path to a configuration parameters .ini file.') 54 | 55 | return parser.parse_args(args) 56 | 57 | 58 | def main(args=None): 59 | # parse arguments 60 | if args is None: 61 | args = sys.argv[1:] 62 | args = parse_args(args) 63 | 64 | # Set modified tf session to avoid using the GPUs 65 | keras.backend.tensorflow_backend.set_session(get_session()) 66 | 67 | # optionally load config parameters 68 | anchor_parameters = None 69 | if args.config: 70 | args.config = read_config_file(args.config) 71 | if 'anchor_parameters' in args.config: 72 | anchor_parameters = parse_anchor_parameters(args.config) 73 | 74 | # load the model 75 | model = models.load_model(args.model_in, backbone_name=args.backbone) 76 | 77 | # check if this is indeed a training model 78 | models.check_training_model(model) 79 | 80 | # convert the model 81 | model = models.convert_model(model, nms=args.nms, class_specific_filter=args.class_specific_filter, anchor_params=anchor_parameters) 82 | 83 | # save model 84 | model.save(args.model_out) 85 | 86 | 87 | if __name__ == '__main__': 88 | main() 89 | -------------------------------------------------------------------------------- /keras_retinanet/utils/coco_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | from pycocotools.cocoeval import COCOeval 18 | 19 | import keras 20 | import numpy as np 21 | import json 22 | 23 | import progressbar 24 | assert(callable(progressbar.progressbar)), "Using wrong progressbar module, install 'progressbar2' instead." 25 | 26 | 27 | def evaluate_coco(generator, model, threshold=0.05): 28 | """ Use the pycocotools to evaluate a COCO model on a dataset. 29 | 30 | Args 31 | generator : The generator for generating the evaluation data. 32 | model : The model to evaluate. 33 | threshold : The score threshold to use. 
34 | """ 35 | # start collecting results 36 | results = [] 37 | image_ids = [] 38 | for index in progressbar.progressbar(range(generator.size()), prefix='COCO evaluation: '): 39 | image = generator.load_image(index) 40 | image = generator.preprocess_image(image) 41 | image, scale = generator.resize_image(image) 42 | 43 | if keras.backend.image_data_format() == 'channels_first': 44 | image = image.transpose((2, 0, 1)) 45 | 46 | # run network 47 | boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0)) 48 | 49 | # correct boxes for image scale 50 | boxes /= scale 51 | 52 | # change to (x, y, w, h) (MS COCO standard) 53 | boxes[:, :, 2] -= boxes[:, :, 0] 54 | boxes[:, :, 3] -= boxes[:, :, 1] 55 | 56 | # compute predicted labels and scores 57 | for box, score, label in zip(boxes[0], scores[0], labels[0]): 58 | # scores are sorted, so we can break 59 | if score < threshold: 60 | break 61 | 62 | # append detection for each positively labeled class 63 | image_result = { 64 | 'image_id' : generator.image_ids[index], 65 | 'category_id' : generator.label_to_coco_label(label), 66 | 'score' : float(score), 67 | 'bbox' : box.tolist(), 68 | } 69 | 70 | # append detection to results 71 | results.append(image_result) 72 | 73 | # append image to list of processed images 74 | image_ids.append(generator.image_ids[index]) 75 | 76 | if not len(results): 77 | return 78 | 79 | # write output 80 | json.dump(results, open('{}_bbox_results.json'.format(generator.set_name), 'w'), indent=4) 81 | json.dump(image_ids, open('{}_processed_image_ids.json'.format(generator.set_name), 'w'), indent=4) 82 | 83 | # load results in COCO evaluation tool 84 | coco_true = generator.coco 85 | coco_pred = coco_true.loadRes('{}_bbox_results.json'.format(generator.set_name)) 86 | 87 | # run COCO evaluation 88 | coco_eval = COCOeval(coco_true, coco_pred, 'bbox') 89 | coco_eval.params.imgIds = image_ids 90 | coco_eval.evaluate() 91 | coco_eval.accumulate() 92 | coco_eval.summarize() 93 | return coco_eval.stats 94 | -------------------------------------------------------------------------------- /keras_retinanet/backend/tensorflow_backend.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import tensorflow 18 | 19 | 20 | def ones(*args, **kwargs): 21 | """ See https://www.tensorflow.org/versions/master/api_docs/python/tf/ones . 22 | """ 23 | return tensorflow.ones(*args, **kwargs) 24 | 25 | 26 | def transpose(*args, **kwargs): 27 | """ See https://www.tensorflow.org/versions/master/api_docs/python/tf/transpose . 28 | """ 29 | return tensorflow.transpose(*args, **kwargs) 30 | 31 | 32 | def map_fn(*args, **kwargs): 33 | """ See https://www.tensorflow.org/versions/master/api_docs/python/tf/map_fn . 
34 | """ 35 | return tensorflow.map_fn(*args, **kwargs) 36 | 37 | 38 | def pad(*args, **kwargs): 39 | """ See https://www.tensorflow.org/versions/master/api_docs/python/tf/pad . 40 | """ 41 | return tensorflow.pad(*args, **kwargs) 42 | 43 | 44 | def top_k(*args, **kwargs): 45 | """ See https://www.tensorflow.org/versions/master/api_docs/python/tf/nn/top_k . 46 | """ 47 | return tensorflow.nn.top_k(*args, **kwargs) 48 | 49 | 50 | def clip_by_value(*args, **kwargs): 51 | """ See https://www.tensorflow.org/versions/master/api_docs/python/tf/clip_by_value . 52 | """ 53 | return tensorflow.clip_by_value(*args, **kwargs) 54 | 55 | 56 | def resize_images(images, size, method='bilinear', align_corners=False): 57 | """ See https://www.tensorflow.org/versions/master/api_docs/python/tf/image/resize_images . 58 | 59 | Args 60 | method: The method used for interpolation. One of ('bilinear', 'nearest', 'bicubic', 'area'). 61 | """ 62 | methods = { 63 | 'bilinear': tensorflow.image.ResizeMethod.BILINEAR, 64 | 'nearest' : tensorflow.image.ResizeMethod.NEAREST_NEIGHBOR, 65 | 'bicubic' : tensorflow.image.ResizeMethod.BICUBIC, 66 | 'area' : tensorflow.image.ResizeMethod.AREA, 67 | } 68 | return tensorflow.image.resize_images(images, size, methods[method], align_corners) 69 | 70 | 71 | def non_max_suppression(*args, **kwargs): 72 | """ See https://www.tensorflow.org/versions/master/api_docs/python/tf/image/non_max_suppression . 73 | """ 74 | return tensorflow.image.non_max_suppression(*args, **kwargs) 75 | 76 | 77 | def range(*args, **kwargs): 78 | """ See https://www.tensorflow.org/versions/master/api_docs/python/tf/range . 79 | """ 80 | return tensorflow.range(*args, **kwargs) 81 | 82 | 83 | def scatter_nd(*args, **kwargs): 84 | """ See https://www.tensorflow.org/versions/master/api_docs/python/tf/scatter_nd . 85 | """ 86 | return tensorflow.scatter_nd(*args, **kwargs) 87 | 88 | 89 | def gather_nd(*args, **kwargs): 90 | """ See https://www.tensorflow.org/versions/master/api_docs/python/tf/gather_nd . 91 | """ 92 | return tensorflow.gather_nd(*args, **kwargs) 93 | 94 | 95 | def meshgrid(*args, **kwargs): 96 | """ See https://www.tensorflow.org/versions/master/api_docs/python/tf/meshgrid . 97 | """ 98 | return tensorflow.meshgrid(*args, **kwargs) 99 | 100 | 101 | def where(*args, **kwargs): 102 | """ See https://www.tensorflow.org/versions/master/api_docs/python/tf/where . 103 | """ 104 | return tensorflow.where(*args, **kwargs) 105 | -------------------------------------------------------------------------------- /keras_retinanet/backend/common.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import keras.backend 18 | from .dynamic import meshgrid 19 | 20 | 21 | def bbox_transform_inv(boxes, deltas, mean=None, std=None): 22 | """ Applies deltas (usually regression results) to boxes (usually anchors). 
23 | 24 | Before applying the deltas to the boxes, the normalization that was previously applied (in the generator) has to be removed. 25 | The mean and std are the mean and std as applied in the generator. They are unnormalized in this function and then applied to the boxes. 26 | 27 | Args 28 | boxes : np.array of shape (B, N, 4), where B is the batch size, N the number of boxes and 4 values for (x1, y1, x2, y2). 29 | deltas: np.array of same shape as boxes. These deltas (d_x1, d_y1, d_x2, d_y2) are a factor of the width/height. 30 | mean : The mean value used when computing deltas (defaults to [0, 0, 0, 0]). 31 | std : The standard deviation used when computing deltas (defaults to [0.2, 0.2, 0.2, 0.2]). 32 | 33 | Returns 34 | A np.array of the same shape as boxes, but with deltas applied to each box. 35 | The mean and std are used during training to normalize the regression values (networks love normalization). 36 | """ 37 | if mean is None: 38 | mean = [0, 0, 0, 0] 39 | if std is None: 40 | std = [0.2, 0.2, 0.2, 0.2] 41 | 42 | width = boxes[:, :, 2] - boxes[:, :, 0] 43 | height = boxes[:, :, 3] - boxes[:, :, 1] 44 | 45 | x1 = boxes[:, :, 0] + (deltas[:, :, 0] * std[0] + mean[0]) * width 46 | y1 = boxes[:, :, 1] + (deltas[:, :, 1] * std[1] + mean[1]) * height 47 | x2 = boxes[:, :, 2] + (deltas[:, :, 2] * std[2] + mean[2]) * width 48 | y2 = boxes[:, :, 3] + (deltas[:, :, 3] * std[3] + mean[3]) * height 49 | 50 | pred_boxes = keras.backend.stack([x1, y1, x2, y2], axis=2) 51 | 52 | return pred_boxes 53 | 54 | 55 | def shift(shape, stride, anchors): 56 | """ Produce shifted anchors based on shape of the map and stride size. 57 | 58 | Args 59 | shape : Shape to shift the anchors over. 60 | stride : Stride to shift the anchors with over the shape. 61 | anchors: The anchors to apply at each location. 62 | """ 63 | shift_x = (keras.backend.arange(0, shape[1], dtype=keras.backend.floatx()) + keras.backend.constant(0.5, dtype=keras.backend.floatx())) * stride 64 | shift_y = (keras.backend.arange(0, shape[0], dtype=keras.backend.floatx()) + keras.backend.constant(0.5, dtype=keras.backend.floatx())) * stride 65 | 66 | shift_x, shift_y = meshgrid(shift_x, shift_y) 67 | shift_x = keras.backend.reshape(shift_x, [-1]) 68 | shift_y = keras.backend.reshape(shift_y, [-1]) 69 | 70 | shifts = keras.backend.stack([ 71 | shift_x, 72 | shift_y, 73 | shift_x, 74 | shift_y 75 | ], axis=0) 76 | 77 | shifts = keras.backend.transpose(shifts) 78 | number_of_anchors = keras.backend.shape(anchors)[0] 79 | 80 | k = keras.backend.shape(shifts)[0] # number of base points = feat_h * feat_w 81 | 82 | shifted_anchors = keras.backend.reshape(anchors, [1, number_of_anchors, 4]) + keras.backend.cast(keras.backend.reshape(shifts, [k, 1, 4]), keras.backend.floatx()) 83 | shifted_anchors = keras.backend.reshape(shifted_anchors, [k * number_of_anchors, 4]) 84 | 85 | return shifted_anchors 86 | -------------------------------------------------------------------------------- /keras_retinanet/callbacks/eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import keras 18 | from ..utils.eval import evaluate 19 | 20 | 21 | class Evaluate(keras.callbacks.Callback): 22 | """ Evaluation callback for arbitrary datasets. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | generator, 28 | iou_threshold=0.5, 29 | score_threshold=0.05, 30 | max_detections=100, 31 | save_path=None, 32 | tensorboard=None, 33 | weighted_average=False, 34 | verbose=1 35 | ): 36 | """ Evaluate a given dataset using a given model at the end of every epoch during training. 37 | 38 | # Arguments 39 | generator : The generator that represents the dataset to evaluate. 40 | iou_threshold : The threshold used to consider when a detection is positive or negative. 41 | score_threshold : The score confidence threshold to use for detections. 42 | max_detections : The maximum number of detections to use per image. 43 | save_path : The path to save images with visualized detections to. 44 | tensorboard : Instance of keras.callbacks.TensorBoard used to log the mAP value. 45 | weighted_average : Compute the mAP using the weighted average of precisions among classes. 46 | verbose : Set the verbosity level, by default this is set to 1. 47 | """ 48 | self.generator = generator 49 | self.iou_threshold = iou_threshold 50 | self.score_threshold = score_threshold 51 | self.max_detections = max_detections 52 | self.save_path = save_path 53 | self.tensorboard = tensorboard 54 | self.weighted_average = weighted_average 55 | self.verbose = verbose 56 | 57 | super(Evaluate, self).__init__() 58 | 59 | def on_epoch_end(self, epoch, logs=None): 60 | logs = logs or {} 61 | 62 | # run evaluation 63 | average_precisions = evaluate( 64 | self.generator, 65 | self.model, 66 | iou_threshold=self.iou_threshold, 67 | score_threshold=self.score_threshold, 68 | max_detections=self.max_detections, 69 | save_path=self.save_path 70 | ) 71 | 72 | # compute per class average precision 73 | total_instances = [] 74 | precisions = [] 75 | for label, (average_precision, num_annotations ) in average_precisions.items(): 76 | if self.verbose == 1: 77 | print('{:.0f} instances of class'.format(num_annotations), 78 | self.generator.label_to_name(label), 'with average precision: {:.4f}'.format(average_precision)) 79 | total_instances.append(num_annotations) 80 | precisions.append(average_precision) 81 | if self.weighted_average: 82 | self.mean_ap = sum([a * b for a, b in zip(total_instances, precisions)]) / sum(total_instances) 83 | else: 84 | self.mean_ap = sum(precisions) / sum(x > 0 for x in total_instances) 85 | 86 | if self.tensorboard is not None and self.tensorboard.writer is not None: 87 | import tensorflow as tf 88 | summary = tf.Summary() 89 | summary_value = summary.value.add() 90 | summary_value.simple_value = self.mean_ap 91 | summary_value.tag = "mAP" 92 | self.tensorboard.writer.add_summary(summary, epoch) 93 | 94 | logs['mAP'] = self.mean_ap 95 | 96 | if self.verbose == 1: 97 | print('mAP: {:.4f}'.format(self.mean_ap)) 98 | -------------------------------------------------------------------------------- /keras_retinanet/models/vgg.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 cgratie (https://github.com/cgratie/) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | import keras 19 | from keras.utils import get_file 20 | 21 | from . import retinanet 22 | from . import Backbone 23 | from ..utils.image import preprocess_image 24 | 25 | 26 | class VGGBackbone(Backbone): 27 | """ Describes backbone information and provides utility functions. 28 | """ 29 | 30 | def retinanet(self, *args, **kwargs): 31 | """ Returns a retinanet model using the correct backbone. 32 | """ 33 | return vgg_retinanet(*args, backbone=self.backbone, **kwargs) 34 | 35 | def download_imagenet(self): 36 | """ Downloads ImageNet weights and returns path to weights file. 37 | Weights can be downloaded at https://github.com/fizyr/keras-models/releases . 38 | """ 39 | if self.backbone == 'vgg16': 40 | resource = keras.applications.vgg16.vgg16.WEIGHTS_PATH_NO_TOP 41 | checksum = '6d6bbae143d832006294945121d1f1fc' 42 | elif self.backbone == 'vgg19': 43 | resource = keras.applications.vgg19.vgg19.WEIGHTS_PATH_NO_TOP 44 | checksum = '253f8cb515780f3b799900260a226db6' 45 | else: 46 | raise ValueError("Backbone '{}' not recognized.".format(self.backbone)) 47 | 48 | return get_file( 49 | '{}_weights_tf_dim_ordering_tf_kernels_notop.h5'.format(self.backbone), 50 | resource, 51 | cache_subdir='models', 52 | file_hash=checksum 53 | ) 54 | 55 | def validate(self): 56 | """ Checks whether the backbone string is correct. 57 | """ 58 | allowed_backbones = ['vgg16', 'vgg19'] 59 | 60 | if self.backbone not in allowed_backbones: 61 | raise ValueError('Backbone (\'{}\') not in allowed backbones ({}).'.format(self.backbone, allowed_backbones)) 62 | 63 | def preprocess_image(self, inputs): 64 | """ Takes as input an image and prepares it for being passed through the network. 65 | """ 66 | return preprocess_image(inputs, mode='caffe') 67 | 68 | 69 | def vgg_retinanet(num_classes, backbone='vgg16', inputs=None, modifier=None, **kwargs): 70 | """ Constructs a retinanet model using a vgg backbone. 71 | 72 | Args 73 | num_classes: Number of classes to predict. 74 | backbone: Which backbone to use (one of ('vgg16', 'vgg19')). 75 | inputs: The inputs to the network (defaults to a Tensor of shape (None, None, 3)). 76 | modifier: A function handler which can modify the backbone before using it in retinanet (this can be used to freeze backbone layers for example). 77 | 78 | Returns 79 | RetinaNet model with a VGG backbone. 
80 | """ 81 | # choose default input 82 | if inputs is None: 83 | inputs = keras.layers.Input(shape=(None, None, 3)) 84 | 85 | # create the vgg backbone 86 | if backbone == 'vgg16': 87 | vgg = keras.applications.VGG16(input_tensor=inputs, include_top=False, weights=None) 88 | elif backbone == 'vgg19': 89 | vgg = keras.applications.VGG19(input_tensor=inputs, include_top=False, weights=None) 90 | else: 91 | raise ValueError("Backbone '{}' not recognized.".format(backbone)) 92 | 93 | if modifier: 94 | vgg = modifier(vgg) 95 | 96 | # create the full model 97 | layer_names = ["block3_pool", "block4_pool", "block5_pool"] 98 | layer_outputs = [vgg.get_layer(name).output for name in layer_names] 99 | return retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=layer_outputs, **kwargs) 100 | -------------------------------------------------------------------------------- /tests/backend/test_common.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import numpy as np 18 | import keras 19 | import keras_retinanet.backend 20 | 21 | 22 | def test_bbox_transform_inv(): 23 | boxes = np.array([[ 24 | [100, 100, 200, 200], 25 | [100, 100, 300, 300], 26 | [100, 100, 200, 300], 27 | [100, 100, 300, 200], 28 | [80, 120, 200, 200], 29 | [80, 120, 300, 300], 30 | [80, 120, 200, 300], 31 | [80, 120, 300, 200], 32 | ]]) 33 | boxes = keras.backend.variable(boxes) 34 | 35 | deltas = np.array([[ 36 | [0 , 0 , 0 , 0 ], 37 | [0 , 0.1, 0 , 0 ], 38 | [-0.3, 0 , 0 , 0 ], 39 | [0.2 , 0.2, 0 , 0 ], 40 | [0 , 0 , 0.1 , 0 ], 41 | [0 , 0 , 0 , -0.3], 42 | [0 , 0 , 0.2 , 0.2 ], 43 | [0.1 , 0.2, -0.3, 0.4 ], 44 | ]]) 45 | deltas = keras.backend.variable(deltas) 46 | 47 | expected = np.array([[ 48 | [100 , 100 , 200 , 200 ], 49 | [100 , 104 , 300 , 300 ], 50 | [ 94 , 100 , 200 , 300 ], 51 | [108 , 104 , 300 , 200 ], 52 | [ 80 , 120 , 202.4 , 200 ], 53 | [ 80 , 120 , 300 , 289.2], 54 | [ 80 , 120 , 204.8 , 307.2], 55 | [ 84.4, 123.2, 286.8 , 206.4] 56 | ]]) 57 | 58 | result = keras_retinanet.backend.bbox_transform_inv(boxes, deltas) 59 | result = keras.backend.eval(result) 60 | 61 | np.testing.assert_array_almost_equal(result, expected, decimal=2) 62 | 63 | 64 | def test_shift(): 65 | shape = (2, 3) 66 | stride = 8 67 | 68 | anchors = np.array([ 69 | [-8, -8, 8, 8], 70 | [-16, -16, 16, 16], 71 | [-12, -12, 12, 12], 72 | [-12, -16, 12, 16], 73 | [-16, -12, 16, 12] 74 | ], dtype=keras.backend.floatx()) 75 | 76 | expected = [ 77 | # anchors for (0, 0) 78 | [4 - 8, 4 - 8, 4 + 8, 4 + 8], 79 | [4 - 16, 4 - 16, 4 + 16, 4 + 16], 80 | [4 - 12, 4 - 12, 4 + 12, 4 + 12], 81 | [4 - 12, 4 - 16, 4 + 12, 4 + 16], 82 | [4 - 16, 4 - 12, 4 + 16, 4 + 12], 83 | 84 | # anchors for (0, 1) 85 | [12 - 8, 4 - 8, 12 + 8, 4 + 8], 86 | [12 - 16, 4 - 16, 12 + 16, 4 + 16], 87 | [12 - 12, 4 - 12, 12 + 12, 4 + 12], 88 | [12 - 12, 4 - 16, 12 + 12, 4 + 16], 89 | [12 - 16, 4 
- 12, 12 + 16, 4 + 12], 90 | 91 | # anchors for (0, 2) 92 | [20 - 8, 4 - 8, 20 + 8, 4 + 8], 93 | [20 - 16, 4 - 16, 20 + 16, 4 + 16], 94 | [20 - 12, 4 - 12, 20 + 12, 4 + 12], 95 | [20 - 12, 4 - 16, 20 + 12, 4 + 16], 96 | [20 - 16, 4 - 12, 20 + 16, 4 + 12], 97 | 98 | # anchors for (1, 0) 99 | [4 - 8, 12 - 8, 4 + 8, 12 + 8], 100 | [4 - 16, 12 - 16, 4 + 16, 12 + 16], 101 | [4 - 12, 12 - 12, 4 + 12, 12 + 12], 102 | [4 - 12, 12 - 16, 4 + 12, 12 + 16], 103 | [4 - 16, 12 - 12, 4 + 16, 12 + 12], 104 | 105 | # anchors for (1, 1) 106 | [12 - 8, 12 - 8, 12 + 8, 12 + 8], 107 | [12 - 16, 12 - 16, 12 + 16, 12 + 16], 108 | [12 - 12, 12 - 12, 12 + 12, 12 + 12], 109 | [12 - 12, 12 - 16, 12 + 12, 12 + 16], 110 | [12 - 16, 12 - 12, 12 + 16, 12 + 12], 111 | 112 | # anchors for (1, 2) 113 | [20 - 8, 12 - 8, 20 + 8, 12 + 8], 114 | [20 - 16, 12 - 16, 20 + 16, 12 + 16], 115 | [20 - 12, 12 - 12, 20 + 12, 12 + 12], 116 | [20 - 12, 12 - 16, 20 + 12, 12 + 16], 117 | [20 - 16, 12 - 12, 20 + 16, 12 + 12], 118 | ] 119 | 120 | result = keras_retinanet.backend.shift(shape, stride, anchors) 121 | result = keras.backend.eval(result) 122 | 123 | np.testing.assert_array_equal(result, expected) 124 | -------------------------------------------------------------------------------- /keras_retinanet/models/densenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2018 vidosits (https://github.com/vidosits/) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import keras 18 | from keras.applications import densenet 19 | from keras.utils import get_file 20 | 21 | from . import retinanet 22 | from . import Backbone 23 | from ..utils.image import preprocess_image 24 | 25 | 26 | allowed_backbones = { 27 | 'densenet121': ([6, 12, 24, 16], densenet.DenseNet121), 28 | 'densenet169': ([6, 12, 32, 32], densenet.DenseNet169), 29 | 'densenet201': ([6, 12, 48, 32], densenet.DenseNet201), 30 | } 31 | 32 | 33 | class DenseNetBackbone(Backbone): 34 | """ Describes backbone information and provides utility functions. 35 | """ 36 | 37 | def retinanet(self, *args, **kwargs): 38 | """ Returns a retinanet model using the correct backbone. 39 | """ 40 | return densenet_retinanet(*args, backbone=self.backbone, **kwargs) 41 | 42 | def download_imagenet(self): 43 | """ Download pre-trained weights for the specified backbone name. 44 | This name is in the format {backbone}_weights_tf_dim_ordering_tf_kernels_notop 45 | where backbone is the densenet + number of layers (e.g. densenet121). 
46 | For more info check the explanation from the keras densenet script itself: 47 | https://github.com/keras-team/keras/blob/master/keras/applications/densenet.py 48 | """ 49 | origin = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/' 50 | file_name = '{}_weights_tf_dim_ordering_tf_kernels_notop.h5' 51 | 52 | # load weights 53 | if keras.backend.image_data_format() == 'channels_first': 54 | raise ValueError('Weights for "channels_first" format are not available.') 55 | 56 | weights_url = origin + file_name.format(self.backbone) 57 | return get_file(file_name.format(self.backbone), weights_url, cache_subdir='models') 58 | 59 | def validate(self): 60 | """ Checks whether the backbone string is correct. 61 | """ 62 | backbone = self.backbone.split('_')[0] 63 | 64 | if backbone not in allowed_backbones: 65 | raise ValueError('Backbone (\'{}\') not in allowed backbones ({}).'.format(backbone, allowed_backbones.keys())) 66 | 67 | def preprocess_image(self, inputs): 68 | """ Takes as input an image and prepares it for being passed through the network. 69 | """ 70 | return preprocess_image(inputs, mode='tf') 71 | 72 | 73 | def densenet_retinanet(num_classes, backbone='densenet121', inputs=None, modifier=None, **kwargs): 74 | """ Constructs a retinanet model using a densenet backbone. 75 | 76 | Args 77 | num_classes: Number of classes to predict. 78 | backbone: Which backbone to use (one of ('densenet121', 'densenet169', 'densenet201')). 79 | inputs: The inputs to the network (defaults to a Tensor of shape (None, None, 3)). 80 | modifier: A function handler which can modify the backbone before using it in retinanet (this can be used to freeze backbone layers for example). 81 | 82 | Returns 83 | RetinaNet model with a DenseNet backbone. 84 | """ 85 | # choose default input 86 | if inputs is None: 87 | inputs = keras.layers.Input((None, None, 3)) 88 | 89 | blocks, creator = allowed_backbones[backbone] 90 | model = creator(input_tensor=inputs, include_top=False, pooling=None, weights=None) 91 | 92 | # get last conv layer from the end of each dense block 93 | layer_outputs = [model.get_layer(name='conv{}_block{}_concat'.format(idx + 2, block_num)).output for idx, block_num in enumerate(blocks)] 94 | 95 | # create the densenet backbone 96 | model = keras.models.Model(inputs=inputs, outputs=layer_outputs[1:], name=model.name) 97 | 98 | # invoke modifier if given 99 | if modifier: 100 | model = modifier(model) 101 | 102 | # create the full model 103 | model = retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=model.outputs, **kwargs) 104 | 105 | return model 106 | -------------------------------------------------------------------------------- /keras_retinanet/utils/visualization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ 16 | 17 | import cv2 18 | import numpy as np 19 | 20 | from .colors import label_color 21 | 22 | 23 | def draw_box(image, box, color, thickness=2): 24 | """ Draws a box on an image with a given color. 25 | 26 | # Arguments 27 | image : The image to draw on. 28 | box : A list of 4 elements (x1, y1, x2, y2). 29 | color : The color of the box. 30 | thickness : The thickness of the lines to draw a box with. 31 | """ 32 | b = np.array(box).astype(int) 33 | cv2.rectangle(image, (b[0], b[1]), (b[2], b[3]), color, thickness, cv2.LINE_AA) 34 | 35 | 36 | def draw_caption(image, box, caption): 37 | """ Draws a caption above the box in an image. 38 | 39 | # Arguments 40 | image : The image to draw on. 41 | box : A list of 4 elements (x1, y1, x2, y2). 42 | caption : String containing the text to draw. 43 | """ 44 | b = np.array(box).astype(int) 45 | cv2.putText(image, caption, (b[0], b[1] - 10), cv2.FONT_HERSHEY_PLAIN, 1, (0, 0, 0), 2) 46 | cv2.putText(image, caption, (b[0], b[1] - 10), cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 255), 1) 47 | 48 | 49 | def draw_boxes(image, boxes, color, thickness=2): 50 | """ Draws boxes on an image with a given color. 51 | 52 | # Arguments 53 | image : The image to draw on. 54 | boxes : A [N, 4] matrix (x1, y1, x2, y2). 55 | color : The color of the boxes. 56 | thickness : The thickness of the lines to draw boxes with. 57 | """ 58 | for b in boxes: 59 | draw_box(image, b, color, thickness=thickness) 60 | 61 | 62 | def draw_detections(image, boxes, scores, labels, color=None, label_to_name=None, score_threshold=0.5): 63 | """ Draws detections in an image. 64 | 65 | # Arguments 66 | image : The image to draw on. 67 | boxes : A [N, 4] matrix (x1, y1, x2, y2). 68 | scores : A list of N classification scores. 69 | labels : A list of N labels. 70 | color : The color of the boxes. By default the color from keras_retinanet.utils.colors.label_color will be used. 71 | label_to_name : (optional) Functor for mapping a label to a name. 72 | score_threshold : Threshold used for determining what detections to draw. 73 | """ 74 | selection = np.where(scores > score_threshold)[0] 75 | 76 | for i in selection: 77 | c = color if color is not None else label_color(labels[i]) 78 | draw_box(image, boxes[i, :], color=c) 79 | 80 | # draw labels 81 | caption = (label_to_name(labels[i]) if label_to_name else labels[i]) + ': {0:.2f}'.format(scores[i]) 82 | draw_caption(image, boxes[i, :], caption) 83 | 84 | 85 | def draw_annotations(image, annotations, color=(0, 255, 0), label_to_name=None): 86 | """ Draws annotations in an image. 87 | 88 | # Arguments 89 | image : The image to draw on. 90 | annotations : A [N, 5] matrix (x1, y1, x2, y2, label) or dictionary containing bboxes (shaped [N, 4]) and labels (shaped [N]). 91 | color : The color of the boxes. By default the color from keras_retinanet.utils.colors.label_color will be used. 92 | label_to_name : (optional) Functor for mapping a label to a name. 
93 | """ 94 | if isinstance(annotations, np.ndarray): 95 | annotations = {'bboxes': annotations[:, :4], 'labels': annotations[:, 4]} 96 | 97 | assert('bboxes' in annotations) 98 | assert('labels' in annotations) 99 | assert(annotations['bboxes'].shape[0] == annotations['labels'].shape[0]) 100 | 101 | for i in range(annotations['bboxes'].shape[0]): 102 | label = annotations['labels'][i] 103 | c = color if color is not None else label_color(label) 104 | caption = '{}'.format(label_to_name(label) if label_to_name else label) 105 | draw_caption(image, annotations['bboxes'][i], caption) 106 | draw_box(image, annotations['bboxes'][i], color=c) 107 | -------------------------------------------------------------------------------- /keras_retinanet/models/mobilenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 lvaleriu (https://github.com/lvaleriu/) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import keras 18 | from keras.applications import mobilenet 19 | from keras.utils import get_file 20 | from ..utils.image import preprocess_image 21 | 22 | from . import retinanet 23 | from . import Backbone 24 | 25 | 26 | class MobileNetBackbone(Backbone): 27 | """ Describes backbone information and provides utility functions. 28 | """ 29 | 30 | allowed_backbones = ['mobilenet128', 'mobilenet160', 'mobilenet192', 'mobilenet224'] 31 | 32 | def retinanet(self, *args, **kwargs): 33 | """ Returns a retinanet model using the correct backbone. 34 | """ 35 | return mobilenet_retinanet(*args, backbone=self.backbone, **kwargs) 36 | 37 | def download_imagenet(self): 38 | """ Download pre-trained weights for the specified backbone name. 39 | This name is in the format mobilenet{rows}_{alpha} where rows is the 40 | imagenet shape dimension and 'alpha' controls the width of the network. 41 | For more info check the explanation from the keras mobilenet script itself. 42 | """ 43 | 44 | alpha = float(self.backbone.split('_')[1]) 45 | rows = int(self.backbone.split('_')[0].replace('mobilenet', '')) 46 | 47 | # load weights 48 | if keras.backend.image_data_format() == 'channels_first': 49 | raise ValueError('Weights for "channels_last" format ' 50 | 'are not available.') 51 | if alpha == 1.0: 52 | alpha_text = '1_0' 53 | elif alpha == 0.75: 54 | alpha_text = '7_5' 55 | elif alpha == 0.50: 56 | alpha_text = '5_0' 57 | else: 58 | alpha_text = '2_5' 59 | 60 | model_name = 'mobilenet_{}_{}_tf_no_top.h5'.format(alpha_text, rows) 61 | weights_url = mobilenet.mobilenet.BASE_WEIGHT_PATH + model_name 62 | weights_path = get_file(model_name, weights_url, cache_subdir='models') 63 | 64 | return weights_path 65 | 66 | def validate(self): 67 | """ Checks whether the backbone string is correct. 
68 | """ 69 | backbone = self.backbone.split('_')[0] 70 | 71 | if backbone not in MobileNetBackbone.allowed_backbones: 72 | raise ValueError('Backbone (\'{}\') not in allowed backbones ({}).'.format(backbone, MobileNetBackbone.allowed_backbones)) 73 | 74 | def preprocess_image(self, inputs): 75 | """ Takes as input an image and prepares it for being passed through the network. 76 | """ 77 | return preprocess_image(inputs, mode='tf') 78 | 79 | 80 | def mobilenet_retinanet(num_classes, backbone='mobilenet224_1.0', inputs=None, modifier=None, **kwargs): 81 | """ Constructs a retinanet model using a mobilenet backbone. 82 | 83 | Args 84 | num_classes: Number of classes to predict. 85 | backbone: Which backbone to use (one of ('mobilenet128', 'mobilenet160', 'mobilenet192', 'mobilenet224')). 86 | inputs: The inputs to the network (defaults to a Tensor of shape (None, None, 3)). 87 | modifier: A function handler which can modify the backbone before using it in retinanet (this can be used to freeze backbone layers for example). 88 | 89 | Returns 90 | RetinaNet model with a MobileNet backbone. 91 | """ 92 | alpha = float(backbone.split('_')[1]) 93 | 94 | # choose default input 95 | if inputs is None: 96 | inputs = keras.layers.Input((None, None, 3)) 97 | 98 | backbone = mobilenet.MobileNet(input_tensor=inputs, alpha=alpha, include_top=False, pooling=None, weights=None) 99 | 100 | # create the full model 101 | layer_names = ['conv_pw_5_relu', 'conv_pw_11_relu', 'conv_pw_13_relu'] 102 | layer_outputs = [backbone.get_layer(name).output for name in layer_names] 103 | backbone = keras.models.Model(inputs=inputs, outputs=layer_outputs, name=backbone.name) 104 | 105 | # invoke modifier if given 106 | if modifier: 107 | backbone = modifier(backbone) 108 | 109 | return retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=backbone.outputs, **kwargs) 110 | -------------------------------------------------------------------------------- /keras_retinanet/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | 4 | 5 | class Backbone(object): 6 | """ This class stores additional information on backbones. 7 | """ 8 | def __init__(self, backbone): 9 | # a dictionary mapping custom layer names to the correct classes 10 | from .. import layers 11 | from .. import losses 12 | from .. import initializers 13 | self.custom_objects = { 14 | 'UpsampleLike' : layers.UpsampleLike, 15 | 'PriorProbability' : initializers.PriorProbability, 16 | 'RegressBoxes' : layers.RegressBoxes, 17 | 'FilterDetections' : layers.FilterDetections, 18 | 'Anchors' : layers.Anchors, 19 | 'ClipBoxes' : layers.ClipBoxes, 20 | '_smooth_l1' : losses.smooth_l1(), 21 | '_focal' : losses.focal(), 22 | } 23 | 24 | self.backbone = backbone 25 | self.validate() 26 | 27 | def retinanet(self, *args, **kwargs): 28 | """ Returns a retinanet model using the correct backbone. 29 | """ 30 | raise NotImplementedError('retinanet method not implemented.') 31 | 32 | def download_imagenet(self): 33 | """ Downloads ImageNet weights and returns path to weights file. 34 | """ 35 | raise NotImplementedError('download_imagenet method not implemented.') 36 | 37 | def validate(self): 38 | """ Checks whether the backbone string is correct. 39 | """ 40 | raise NotImplementedError('validate method not implemented.') 41 | 42 | def preprocess_image(self, inputs): 43 | """ Takes as input an image and prepares it for being passed through the network. 
44 | Having this function in Backbone allows other backbones to define a specific preprocessing step. 45 | """ 46 | raise NotImplementedError('preprocess_image method not implemented.') 47 | 48 | 49 | def backbone(backbone_name): 50 | """ Returns a backbone object for the given backbone. 51 | """ 52 | if 'resnet' in backbone_name: 53 | from .resnet import ResNetBackbone as b 54 | elif 'mobilenet' in backbone_name: 55 | from .mobilenet import MobileNetBackbone as b 56 | elif 'vgg' in backbone_name: 57 | from .vgg import VGGBackbone as b 58 | elif 'densenet' in backbone_name: 59 | from .densenet import DenseNetBackbone as b 60 | else: 61 | raise NotImplementedError('Backbone class for \'{}\' not implemented.'.format(backbone_name)) 62 | 63 | return b(backbone_name) 64 | 65 | 66 | def load_model(filepath, backbone_name='resnet50'): 67 | """ Loads a retinanet model using the correct custom objects. 68 | 69 | Args 70 | filepath: one of the following: 71 | - string, path to the saved model, or 72 | - h5py.File object from which to load the model 73 | backbone_name : Backbone with which the model was trained. 74 | 75 | Returns 76 | A keras.models.Model object. 77 | 78 | Raises 79 | ImportError: if h5py is not available. 80 | ValueError: In case of an invalid savefile. 81 | """ 82 | import keras.models 83 | return keras.models.load_model(filepath, custom_objects=backbone(backbone_name).custom_objects) 84 | 85 | 86 | def convert_model(model, nms=True, class_specific_filter=True, anchor_params=None): 87 | """ Converts a training model to an inference model. 88 | 89 | Args 90 | model : A retinanet training model. 91 | nms : Boolean, whether to add NMS filtering to the converted model. 92 | class_specific_filter : Whether to use class specific filtering or filter for the best scoring class only. 93 | anchor_params : Anchor parameters object. If omitted, default values are used. 94 | 95 | Returns 96 | A keras.models.Model object. 97 | 98 | Raises 99 | ImportError: if h5py is not available. 100 | ValueError: In case of an invalid savefile. 101 | """ 102 | from .retinanet import retinanet_bbox 103 | return retinanet_bbox(model=model, nms=nms, class_specific_filter=class_specific_filter, anchor_params=anchor_params) 104 | 105 | 106 | def assert_training_model(model): 107 | """ Assert that the model is a training model. 108 | """ 109 | assert(all(output in model.output_names for output in ['regression', 'classification'])), \ 110 | "Input is not a training model (no 'regression' and 'classification' outputs were found, outputs are: {}).".format(model.output_names) 111 | 112 | 113 | def check_training_model(model): 114 | """ Check that model is a training model and exit otherwise. 115 | """ 116 | try: 117 | assert_training_model(model) 118 | except AssertionError as e: 119 | print(e, file=sys.stderr) 120 | sys.exit(1) 121 | -------------------------------------------------------------------------------- /keras_retinanet/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import keras 18 | from . import backend 19 | 20 | 21 | def focal(alpha=0.25, gamma=2.0): 22 | """ Create a functor for computing the focal loss. 23 | 24 | Args 25 | alpha: Scale the focal weight with alpha. 26 | gamma: Take the power of the focal weight with gamma. 27 | 28 | Returns 29 | A functor that computes the focal loss using the alpha and gamma. 30 | """ 31 | def _focal(y_true, y_pred): 32 | """ Compute the focal loss given the target tensor and the predicted tensor. 33 | 34 | As defined in https://arxiv.org/abs/1708.02002 35 | 36 | Args 37 | y_true: Tensor of target data from the generator with shape (B, N, num_classes). 38 | y_pred: Tensor of predicted data from the network with shape (B, N, num_classes). 39 | 40 | Returns 41 | The focal loss of y_pred w.r.t. y_true. 42 | """ 43 | labels = y_true[:, :, :-1] 44 | anchor_state = y_true[:, :, -1] # -1 for ignore, 0 for background, 1 for object 45 | classification = y_pred 46 | 47 | # filter out "ignore" anchors 48 | indices = backend.where(keras.backend.not_equal(anchor_state, -1)) 49 | labels = backend.gather_nd(labels, indices) 50 | classification = backend.gather_nd(classification, indices) 51 | 52 | # compute the focal loss 53 | alpha_factor = keras.backend.ones_like(labels) * alpha 54 | alpha_factor = backend.where(keras.backend.equal(labels, 1), alpha_factor, 1 - alpha_factor) 55 | focal_weight = backend.where(keras.backend.equal(labels, 1), 1 - classification, classification) 56 | focal_weight = alpha_factor * focal_weight ** gamma 57 | 58 | cls_loss = focal_weight * keras.backend.binary_crossentropy(labels, classification) 59 | 60 | # compute the normalizer: the number of positive anchors 61 | normalizer = backend.where(keras.backend.equal(anchor_state, 1)) 62 | normalizer = keras.backend.cast(keras.backend.shape(normalizer)[0], keras.backend.floatx()) 63 | normalizer = keras.backend.maximum(keras.backend.cast_to_floatx(1.0), normalizer) 64 | 65 | return keras.backend.sum(cls_loss) / normalizer 66 | 67 | return _focal 68 | 69 | 70 | def smooth_l1(sigma=3.0): 71 | """ Create a smooth L1 loss functor. 72 | 73 | Args 74 | sigma: This argument defines the point where the loss changes from L2 to L1. 75 | 76 | Returns 77 | A functor for computing the smooth L1 loss given target data and predicted data. 78 | """ 79 | sigma_squared = sigma ** 2 80 | 81 | def _smooth_l1(y_true, y_pred): 82 | """ Compute the smooth L1 loss of y_pred w.r.t. y_true. 83 | 84 | Args 85 | y_true: Tensor from the generator of shape (B, N, 5). The last value for each box is the state of the anchor (ignore, negative, positive). 86 | y_pred: Tensor from the network of shape (B, N, 4). 87 | 88 | Returns 89 | The smooth L1 loss of y_pred w.r.t. y_true. 
90 | """ 91 | # separate target and state 92 | regression = y_pred 93 | regression_target = y_true[:, :, :-1] 94 | anchor_state = y_true[:, :, -1] 95 | 96 | # filter out "ignore" anchors 97 | indices = backend.where(keras.backend.equal(anchor_state, 1)) 98 | regression = backend.gather_nd(regression, indices) 99 | regression_target = backend.gather_nd(regression_target, indices) 100 | 101 | # compute smooth L1 loss 102 | # f(x) = 0.5 * (sigma * x)^2 if |x| < 1 / sigma / sigma 103 | # |x| - 0.5 / sigma / sigma otherwise 104 | regression_diff = regression - regression_target 105 | regression_diff = keras.backend.abs(regression_diff) 106 | regression_loss = backend.where( 107 | keras.backend.less(regression_diff, 1.0 / sigma_squared), 108 | 0.5 * sigma_squared * keras.backend.pow(regression_diff, 2), 109 | regression_diff - 0.5 / sigma_squared 110 | ) 111 | 112 | # compute the normalizer: the number of positive anchors 113 | normalizer = keras.backend.maximum(1, keras.backend.shape(indices)[0]) 114 | normalizer = keras.backend.cast(normalizer, dtype=keras.backend.floatx()) 115 | return keras.backend.sum(regression_loss) / normalizer 116 | 117 | return _smooth_l1 118 | -------------------------------------------------------------------------------- /keras_retinanet/models/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import keras 18 | from keras.utils import get_file 19 | import keras_resnet 20 | import keras_resnet.models 21 | 22 | from . import retinanet 23 | from . import Backbone 24 | from ..utils.image import preprocess_image 25 | 26 | 27 | class ResNetBackbone(Backbone): 28 | """ Describes backbone information and provides utility functions. 29 | """ 30 | 31 | def __init__(self, backbone): 32 | super(ResNetBackbone, self).__init__(backbone) 33 | self.custom_objects.update(keras_resnet.custom_objects) 34 | 35 | def retinanet(self, *args, **kwargs): 36 | """ Returns a retinanet model using the correct backbone. 37 | """ 38 | return resnet_retinanet(*args, backbone=self.backbone, **kwargs) 39 | 40 | def download_imagenet(self): 41 | """ Downloads ImageNet weights and returns path to weights file. 
42 | """ 43 | resnet_filename = 'ResNet-{}-model.keras.h5' 44 | resnet_resource = 'https://github.com/fizyr/keras-models/releases/download/v0.0.1/{}'.format(resnet_filename) 45 | depth = int(self.backbone.replace('resnet', '')) 46 | 47 | filename = resnet_filename.format(depth) 48 | resource = resnet_resource.format(depth) 49 | if depth == 50: 50 | checksum = '3e9f4e4f77bbe2c9bec13b53ee1c2319' 51 | elif depth == 101: 52 | checksum = '05dc86924389e5b401a9ea0348a3213c' 53 | elif depth == 152: 54 | checksum = '6ee11ef2b135592f8031058820bb9e71' 55 | 56 | return get_file( 57 | filename, 58 | resource, 59 | cache_subdir='models', 60 | md5_hash=checksum 61 | ) 62 | 63 | def validate(self): 64 | """ Checks whether the backbone string is correct. 65 | """ 66 | allowed_backbones = ['resnet50', 'resnet101', 'resnet152'] 67 | backbone = self.backbone.split('_')[0] 68 | 69 | if backbone not in allowed_backbones: 70 | raise ValueError('Backbone (\'{}\') not in allowed backbones ({}).'.format(backbone, allowed_backbones)) 71 | 72 | def preprocess_image(self, inputs): 73 | """ Takes as input an image and prepares it for being passed through the network. 74 | """ 75 | return preprocess_image(inputs, mode='caffe') 76 | 77 | 78 | def resnet_retinanet(num_classes, backbone='resnet50', inputs=None, modifier=None, **kwargs): 79 | """ Constructs a retinanet model using a resnet backbone. 80 | 81 | Args 82 | num_classes: Number of classes to predict. 83 | backbone: Which backbone to use (one of ('resnet50', 'resnet101', 'resnet152')). 84 | inputs: The inputs to the network (defaults to a Tensor of shape (None, None, 3)). 85 | modifier: A function handler which can modify the backbone before using it in retinanet (this can be used to freeze backbone layers for example). 86 | 87 | Returns 88 | RetinaNet model with a ResNet backbone. 
89 | """ 90 | # choose default input 91 | if inputs is None: 92 | if keras.backend.image_data_format() == 'channels_first': 93 | inputs = keras.layers.Input(shape=(3, None, None)) 94 | else: 95 | inputs = keras.layers.Input(shape=(None, None, 3)) 96 | 97 | # create the resnet backbone 98 | if backbone == 'resnet50': 99 | resnet = keras_resnet.models.ResNet50(inputs, include_top=False, freeze_bn=True) 100 | elif backbone == 'resnet101': 101 | resnet = keras_resnet.models.ResNet101(inputs, include_top=False, freeze_bn=True) 102 | elif backbone == 'resnet152': 103 | resnet = keras_resnet.models.ResNet152(inputs, include_top=False, freeze_bn=True) 104 | else: 105 | raise ValueError('Backbone (\'{}\') is invalid.'.format(backbone)) 106 | 107 | # invoke modifier if given 108 | if modifier: 109 | resnet = modifier(resnet) 110 | 111 | # create the full model 112 | return retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=resnet.outputs[1:], **kwargs) 113 | 114 | 115 | def resnet50_retinanet(num_classes, inputs=None, **kwargs): 116 | return resnet_retinanet(num_classes=num_classes, backbone='resnet50', inputs=inputs, **kwargs) 117 | 118 | 119 | def resnet101_retinanet(num_classes, inputs=None, **kwargs): 120 | return resnet_retinanet(num_classes=num_classes, backbone='resnet101', inputs=inputs, **kwargs) 121 | 122 | 123 | def resnet152_retinanet(num_classes, inputs=None, **kwargs): 124 | return resnet_retinanet(num_classes=num_classes, backbone='resnet152', inputs=inputs, **kwargs) 125 | -------------------------------------------------------------------------------- /keras_retinanet/preprocessing/coco.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | from ..preprocessing.generator import Generator 18 | from ..utils.image import read_image_bgr 19 | 20 | import os 21 | import numpy as np 22 | 23 | from pycocotools.coco import COCO 24 | 25 | 26 | class CocoGenerator(Generator): 27 | """ Generate data from the COCO dataset. 28 | 29 | See https://github.com/cocodataset/cocoapi/tree/master/PythonAPI for more information. 30 | """ 31 | 32 | def __init__(self, data_dir, set_name, **kwargs): 33 | """ Initialize a COCO data generator. 34 | 35 | Args 36 | data_dir: Path to where the COCO dataset is stored. 37 | set_name: Name of the set to parse. 38 | """ 39 | self.data_dir = data_dir 40 | self.set_name = set_name 41 | self.coco = COCO(os.path.join(data_dir, 'annotations', 'instances_' + set_name + '.json')) 42 | self.image_ids = self.coco.getImgIds() 43 | 44 | self.load_classes() 45 | 46 | super(CocoGenerator, self).__init__(**kwargs) 47 | 48 | def load_classes(self): 49 | """ Loads the class to label mapping (and inverse) for COCO. 
50 | """ 51 | # load class names (name -> label) 52 | categories = self.coco.loadCats(self.coco.getCatIds()) 53 | categories.sort(key=lambda x: x['id']) 54 | 55 | self.classes = {} 56 | self.coco_labels = {} 57 | self.coco_labels_inverse = {} 58 | for c in categories: 59 | self.coco_labels[len(self.classes)] = c['id'] 60 | self.coco_labels_inverse[c['id']] = len(self.classes) 61 | self.classes[c['name']] = len(self.classes) 62 | 63 | # also load the reverse (label -> name) 64 | self.labels = {} 65 | for key, value in self.classes.items(): 66 | self.labels[value] = key 67 | 68 | def size(self): 69 | """ Size of the COCO dataset. 70 | """ 71 | return len(self.image_ids) 72 | 73 | def num_classes(self): 74 | """ Number of classes in the dataset. For COCO this is 80. 75 | """ 76 | return len(self.classes) 77 | 78 | def has_label(self, label): 79 | """ Return True if label is a known label. 80 | """ 81 | return label in self.labels 82 | 83 | def has_name(self, name): 84 | """ Returns True if name is a known class. 85 | """ 86 | return name in self.classes 87 | 88 | def name_to_label(self, name): 89 | """ Map name to label. 90 | """ 91 | return self.classes[name] 92 | 93 | def label_to_name(self, label): 94 | """ Map label to name. 95 | """ 96 | return self.labels[label] 97 | 98 | def coco_label_to_label(self, coco_label): 99 | """ Map COCO label to the label as used in the network. 100 | COCO has some gaps in the order of labels. The highest label is 90, but there are 80 classes. 101 | """ 102 | return self.coco_labels_inverse[coco_label] 103 | 104 | def coco_label_to_name(self, coco_label): 105 | """ Map COCO label to name. 106 | """ 107 | return self.label_to_name(self.coco_label_to_label(coco_label)) 108 | 109 | def label_to_coco_label(self, label): 110 | """ Map label as used by the network to labels as used by COCO. 111 | """ 112 | return self.coco_labels[label] 113 | 114 | def image_aspect_ratio(self, image_index): 115 | """ Compute the aspect ratio for an image with image_index. 116 | """ 117 | image = self.coco.loadImgs(self.image_ids[image_index])[0] 118 | return float(image['width']) / float(image['height']) 119 | 120 | def load_image(self, image_index): 121 | """ Load an image at the image_index. 122 | """ 123 | image_info = self.coco.loadImgs(self.image_ids[image_index])[0] 124 | path = os.path.join(self.data_dir, 'images', self.set_name, image_info['file_name']) 125 | return read_image_bgr(path) 126 | 127 | def load_annotations(self, image_index): 128 | """ Load annotations for an image_index. 
129 | """ 130 | # get ground truth annotations 131 | annotations_ids = self.coco.getAnnIds(imgIds=self.image_ids[image_index], iscrowd=False) 132 | annotations = {'labels': np.empty((0,)), 'bboxes': np.empty((0, 4))} 133 | 134 | # some images appear to miss annotations (like image with id 257034) 135 | if len(annotations_ids) == 0: 136 | return annotations 137 | 138 | # parse annotations 139 | coco_annotations = self.coco.loadAnns(annotations_ids) 140 | for idx, a in enumerate(coco_annotations): 141 | # some annotations have basically no width / height, skip them 142 | if a['bbox'][2] < 1 or a['bbox'][3] < 1: 143 | continue 144 | 145 | annotations['labels'] = np.concatenate([annotations['labels'], [self.coco_label_to_label(a['category_id'])]], axis=0) 146 | annotations['bboxes'] = np.concatenate([annotations['bboxes'], [[ 147 | a['bbox'][0], 148 | a['bbox'][1], 149 | a['bbox'][0] + a['bbox'][2], 150 | a['bbox'][1] + a['bbox'][3], 151 | ]]], axis=0) 152 | 153 | return annotations 154 | -------------------------------------------------------------------------------- /keras_retinanet/preprocessing/kitti.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 lvaleriu (https://github.com/lvaleriu/) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import csv 18 | import os.path 19 | 20 | import numpy as np 21 | from PIL import Image 22 | 23 | from .generator import Generator 24 | from ..utils.image import read_image_bgr 25 | 26 | kitti_classes = { 27 | 'Car': 0, 28 | 'Van': 1, 29 | 'Truck': 2, 30 | 'Pedestrian': 3, 31 | 'Person_sitting': 4, 32 | 'Cyclist': 5, 33 | 'Tram': 6, 34 | 'Misc': 7, 35 | 'DontCare': 7 36 | } 37 | 38 | 39 | class KittiGenerator(Generator): 40 | """ Generate data for a KITTI dataset. 41 | 42 | See http://www.cvlibs.net/datasets/kitti/ for more information. 43 | """ 44 | 45 | def __init__( 46 | self, 47 | base_dir, 48 | subset='train', 49 | **kwargs 50 | ): 51 | """ Initialize a KITTI data generator. 52 | 53 | Args 54 | base_dir: Directory w.r.t. where the files are to be searched (defaults to the directory containing the csv_data_file). 55 | subset: The subset to generate data for (defaults to 'train'). 
56 | """ 57 | self.base_dir = base_dir 58 | 59 | label_dir = os.path.join(self.base_dir, subset, 'labels') 60 | image_dir = os.path.join(self.base_dir, subset, 'images') 61 | 62 | """ 63 | 1 type Describes the type of object: 'Car', 'Van', 'Truck', 64 | 'Pedestrian', 'Person_sitting', 'Cyclist', 'Tram', 65 | 'Misc' or 'DontCare' 66 | 1 truncated Float from 0 (non-truncated) to 1 (truncated), where 67 | truncated refers to the object leaving image boundaries 68 | 1 occluded Integer (0,1,2,3) indicating occlusion state: 69 | 0 = fully visible, 1 = partly occluded 70 | 2 = largely occluded, 3 = unknown 71 | 1 alpha Observation angle of object, ranging [-pi..pi] 72 | 4 bbox 2D bounding box of object in the image (0-based index): 73 | contains left, top, right, bottom pixel coordinates 74 | 3 dimensions 3D object dimensions: height, width, length (in meters) 75 | 3 location 3D object location x,y,z in camera coordinates (in meters) 76 | 1 rotation_y Rotation ry around Y-axis in camera coordinates [-pi..pi] 77 | """ 78 | 79 | self.labels = {} 80 | self.classes = kitti_classes 81 | for name, label in self.classes.items(): 82 | self.labels[label] = name 83 | 84 | self.image_data = dict() 85 | self.images = [] 86 | for i, fn in enumerate(os.listdir(label_dir)): 87 | label_fp = os.path.join(label_dir, fn) 88 | image_fp = os.path.join(image_dir, fn.replace('.txt', '.png')) 89 | 90 | self.images.append(image_fp) 91 | 92 | fieldnames = ['type', 'truncated', 'occluded', 'alpha', 'left', 'top', 'right', 'bottom', 'dh', 'dw', 'dl', 93 | 'lx', 'ly', 'lz', 'ry'] 94 | with open(label_fp, 'r') as csv_file: 95 | reader = csv.DictReader(csv_file, delimiter=' ', fieldnames=fieldnames) 96 | boxes = [] 97 | for line, row in enumerate(reader): 98 | label = row['type'] 99 | cls_id = kitti_classes[label] 100 | 101 | annotation = {'cls_id': cls_id, 'x1': row['left'], 'x2': row['right'], 'y2': row['bottom'], 'y1': row['top']} 102 | boxes.append(annotation) 103 | 104 | self.image_data[i] = boxes 105 | 106 | super(KittiGenerator, self).__init__(**kwargs) 107 | 108 | def size(self): 109 | """ Size of the dataset. 110 | """ 111 | return len(self.images) 112 | 113 | def num_classes(self): 114 | """ Number of classes in the dataset. 115 | """ 116 | return max(self.classes.values()) + 1 117 | 118 | def has_label(self, label): 119 | """ Return True if label is a known label. 120 | """ 121 | return label in self.labels 122 | 123 | def has_name(self, name): 124 | """ Returns True if name is a known class. 125 | """ 126 | return name in self.classes 127 | 128 | def name_to_label(self, name): 129 | """ Map name to label. 130 | """ 131 | raise NotImplementedError() 132 | 133 | def label_to_name(self, label): 134 | """ Map label to name. 135 | """ 136 | return self.labels[label] 137 | 138 | def image_aspect_ratio(self, image_index): 139 | """ Compute the aspect ratio for an image with image_index. 140 | """ 141 | # PIL is fast for metadata 142 | image = Image.open(self.images[image_index]) 143 | return float(image.width) / float(image.height) 144 | 145 | def load_image(self, image_index): 146 | """ Load an image at the image_index. 147 | """ 148 | return read_image_bgr(self.images[image_index]) 149 | 150 | def load_annotations(self, image_index): 151 | """ Load annotations for an image_index. 
152 | """ 153 | image_data = self.image_data[image_index] 154 | annotations = {'labels': np.empty((len(image_data),)), 'bboxes': np.empty((len(image_data), 4))} 155 | 156 | for idx, ann in enumerate(image_data): 157 | annotations['bboxes'][idx, 0] = float(ann['x1']) 158 | annotations['bboxes'][idx, 1] = float(ann['y1']) 159 | annotations['bboxes'][idx, 2] = float(ann['x2']) 160 | annotations['bboxes'][idx, 3] = float(ann['y2']) 161 | annotations['labels'][idx] = int(ann['cls_id']) 162 | 163 | return annotations 164 | -------------------------------------------------------------------------------- /tests/utils/test_transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.testing import assert_almost_equal 3 | from math import pi 4 | 5 | from keras_retinanet.utils.transform import ( 6 | colvec, 7 | transform_aabb, 8 | rotation, random_rotation, 9 | translation, random_translation, 10 | scaling, random_scaling, 11 | shear, random_shear, 12 | random_flip, 13 | random_transform, 14 | random_transform_generator, 15 | change_transform_origin, 16 | ) 17 | 18 | 19 | def test_colvec(): 20 | assert np.array_equal(colvec(0), np.array([[0]])) 21 | assert np.array_equal(colvec(1, 2, 3), np.array([[1], [2], [3]])) 22 | assert np.array_equal(colvec(-1, -2), np.array([[-1], [-2]])) 23 | 24 | 25 | def test_rotation(): 26 | assert_almost_equal(colvec( 1, 0, 1), rotation(0.0 * pi).dot(colvec(1, 0, 1))) 27 | assert_almost_equal(colvec( 0, 1, 1), rotation(0.5 * pi).dot(colvec(1, 0, 1))) 28 | assert_almost_equal(colvec(-1, 0, 1), rotation(1.0 * pi).dot(colvec(1, 0, 1))) 29 | assert_almost_equal(colvec( 0, -1, 1), rotation(1.5 * pi).dot(colvec(1, 0, 1))) 30 | assert_almost_equal(colvec( 1, 0, 1), rotation(2.0 * pi).dot(colvec(1, 0, 1))) 31 | 32 | assert_almost_equal(colvec( 0, 1, 1), rotation(0.0 * pi).dot(colvec(0, 1, 1))) 33 | assert_almost_equal(colvec(-1, 0, 1), rotation(0.5 * pi).dot(colvec(0, 1, 1))) 34 | assert_almost_equal(colvec( 0, -1, 1), rotation(1.0 * pi).dot(colvec(0, 1, 1))) 35 | assert_almost_equal(colvec( 1, 0, 1), rotation(1.5 * pi).dot(colvec(0, 1, 1))) 36 | assert_almost_equal(colvec( 0, 1, 1), rotation(2.0 * pi).dot(colvec(0, 1, 1))) 37 | 38 | 39 | def test_random_rotation(): 40 | prng = np.random.RandomState(0) 41 | for i in range(100): 42 | assert_almost_equal(1, np.linalg.det(random_rotation(-i, i, prng))) 43 | 44 | 45 | def test_translation(): 46 | assert_almost_equal(colvec( 1, 2, 1), translation(colvec( 0, 0)).dot(colvec(1, 2, 1))) 47 | assert_almost_equal(colvec( 4, 6, 1), translation(colvec( 3, 4)).dot(colvec(1, 2, 1))) 48 | assert_almost_equal(colvec(-2, -2, 1), translation(colvec(-3, -4)).dot(colvec(1, 2, 1))) 49 | 50 | 51 | def assert_is_translation(transform, min, max): 52 | assert transform.shape == (3, 3) 53 | assert np.array_equal(transform[:, 0:2], np.eye(3, 2)) 54 | assert transform[2, 2] == 1 55 | assert np.greater_equal(transform[0:2, 2], min).all() 56 | assert np.less( transform[0:2, 2], max).all() 57 | 58 | 59 | def test_random_translation(): 60 | prng = np.random.RandomState(0) 61 | min = (-10, -20) 62 | max = (20, 10) 63 | for i in range(100): 64 | assert_is_translation(random_translation(min, max, prng), min, max) 65 | 66 | 67 | def test_shear(): 68 | assert_almost_equal(colvec( 1, 2, 1), shear(0.0 * pi).dot(colvec(1, 2, 1))) 69 | assert_almost_equal(colvec(-1, 0, 1), shear(0.5 * pi).dot(colvec(1, 2, 1))) 70 | assert_almost_equal(colvec( 1, -2, 1), shear(1.0 * pi).dot(colvec(1, 2, 1))) 71 | 
assert_almost_equal(colvec( 3, 0, 1), shear(1.5 * pi).dot(colvec(1, 2, 1))) 72 | assert_almost_equal(colvec( 1, 2, 1), shear(2.0 * pi).dot(colvec(1, 2, 1))) 73 | 74 | 75 | def assert_is_shear(transform): 76 | assert transform.shape == (3, 3) 77 | assert np.array_equal(transform[:, 0], [1, 0, 0]) 78 | assert np.array_equal(transform[:, 2], [0, 0, 1]) 79 | assert transform[2, 1] == 0 80 | # sin^2 + cos^2 == 1 81 | assert_almost_equal(1, transform[0, 1] ** 2 + transform[1, 1] ** 2) 82 | 83 | 84 | def test_random_shear(): 85 | prng = np.random.RandomState(0) 86 | for i in range(100): 87 | assert_is_shear(random_shear(-pi, pi, prng)) 88 | 89 | 90 | def test_scaling(): 91 | assert_almost_equal(colvec(1.0, 2, 1), scaling(colvec(1.0, 1.0)).dot(colvec(1, 2, 1))) 92 | assert_almost_equal(colvec(0.0, 2, 1), scaling(colvec(0.0, 1.0)).dot(colvec(1, 2, 1))) 93 | assert_almost_equal(colvec(1.0, 0, 1), scaling(colvec(1.0, 0.0)).dot(colvec(1, 2, 1))) 94 | assert_almost_equal(colvec(0.5, 4, 1), scaling(colvec(0.5, 2.0)).dot(colvec(1, 2, 1))) 95 | 96 | 97 | def assert_is_scaling(transform, min, max): 98 | assert transform.shape == (3, 3) 99 | assert np.array_equal(transform[2, :], [0, 0, 1]) 100 | assert np.array_equal(transform[:, 2], [0, 0, 1]) 101 | assert transform[1, 0] == 0 102 | assert transform[0, 1] == 0 103 | assert np.greater_equal(np.diagonal(transform)[:2], min).all() 104 | assert np.less( np.diagonal(transform)[:2], max).all() 105 | 106 | 107 | def test_random_scaling(): 108 | prng = np.random.RandomState(0) 109 | min = (0.1, 0.2) 110 | max = (20, 10) 111 | for i in range(100): 112 | assert_is_scaling(random_scaling(min, max, prng), min, max) 113 | 114 | 115 | def assert_is_flip(transform): 116 | assert transform.shape == (3, 3) 117 | assert np.array_equal(transform[2, :], [0, 0, 1]) 118 | assert np.array_equal(transform[:, 2], [0, 0, 1]) 119 | assert transform[1, 0] == 0 120 | assert transform[0, 1] == 0 121 | assert abs(transform[0, 0]) == 1 122 | assert abs(transform[1, 1]) == 1 123 | 124 | 125 | def test_random_flip(): 126 | prng = np.random.RandomState(0) 127 | for i in range(100): 128 | assert_is_flip(random_flip(0.5, 0.5, prng)) 129 | 130 | 131 | def test_random_transform(): 132 | prng = np.random.RandomState(0) 133 | for i in range(100): 134 | transform = random_transform(prng=prng) 135 | assert np.array_equal(transform, np.identity(3)) 136 | 137 | for i, transform in zip(range(100), random_transform_generator(prng=np.random.RandomState())): 138 | assert np.array_equal(transform, np.identity(3)) 139 | 140 | 141 | def test_transform_aabb(): 142 | assert np.array_equal([1, 2, 3, 4], transform_aabb(np.identity(3), [1, 2, 3, 4])) 143 | assert_almost_equal([-3, -4, -1, -2], transform_aabb(rotation(pi), [1, 2, 3, 4])) 144 | assert_almost_equal([ 2, 4, 4, 6], transform_aabb(translation([1, 2]), [1, 2, 3, 4])) 145 | 146 | 147 | def test_change_transform_origin(): 148 | assert np.array_equal(change_transform_origin(translation([3, 4]), [1, 2]), translation([3, 4])) 149 | assert_almost_equal(colvec(1, 2, 1), change_transform_origin(rotation(pi), [1, 2]).dot(colvec(1, 2, 1))) 150 | assert_almost_equal(colvec(0, 0, 1), change_transform_origin(rotation(pi), [1, 2]).dot(colvec(2, 4, 1))) 151 | assert_almost_equal(colvec(0, 0, 1), change_transform_origin(scaling([0.5, 0.5]), [-2, -4]).dot(colvec(2, 4, 1))) 152 | -------------------------------------------------------------------------------- /tests/layers/test_filter_detections.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import keras 18 | import keras_retinanet.layers 19 | 20 | import numpy as np 21 | 22 | 23 | class TestFilterDetections(object): 24 | def test_simple(self): 25 | # create simple FilterDetections layer 26 | filter_detections_layer = keras_retinanet.layers.FilterDetections() 27 | 28 | # create simple input 29 | boxes = np.array([[ 30 | [0, 0, 10, 10], 31 | [0, 0, 10, 10], # this will be suppressed 32 | ]], dtype=keras.backend.floatx()) 33 | boxes = keras.backend.constant(boxes) 34 | 35 | classification = np.array([[ 36 | [0, 0.9], # this will be suppressed 37 | [0, 1], 38 | ]], dtype=keras.backend.floatx()) 39 | classification = keras.backend.constant(classification) 40 | 41 | # compute output 42 | actual_boxes, actual_scores, actual_labels = filter_detections_layer.call([boxes, classification]) 43 | actual_boxes = keras.backend.eval(actual_boxes) 44 | actual_scores = keras.backend.eval(actual_scores) 45 | actual_labels = keras.backend.eval(actual_labels) 46 | 47 | # define expected output 48 | expected_boxes = -1 * np.ones((1, 300, 4), dtype=keras.backend.floatx()) 49 | expected_boxes[0, 0, :] = [0, 0, 10, 10] 50 | 51 | expected_scores = -1 * np.ones((1, 300), dtype=keras.backend.floatx()) 52 | expected_scores[0, 0] = 1 53 | 54 | expected_labels = -1 * np.ones((1, 300), dtype=keras.backend.floatx()) 55 | expected_labels[0, 0] = 1 56 | 57 | # assert actual and expected are equal 58 | np.testing.assert_array_equal(actual_boxes, expected_boxes) 59 | np.testing.assert_array_equal(actual_scores, expected_scores) 60 | np.testing.assert_array_equal(actual_labels, expected_labels) 61 | 62 | def test_simple_with_other(self): 63 | # create simple FilterDetections layer 64 | filter_detections_layer = keras_retinanet.layers.FilterDetections() 65 | 66 | # create simple input 67 | boxes = np.array([[ 68 | [0, 0, 10, 10], 69 | [0, 0, 10, 10], # this will be suppressed 70 | ]], dtype=keras.backend.floatx()) 71 | boxes = keras.backend.constant(boxes) 72 | 73 | classification = np.array([[ 74 | [0, 0.9], # this will be suppressed 75 | [0, 1], 76 | ]], dtype=keras.backend.floatx()) 77 | classification = keras.backend.constant(classification) 78 | 79 | other = [] 80 | other.append(np.array([[ 81 | [0, 1234], # this will be suppressed 82 | [0, 5678], 83 | ]], dtype=keras.backend.floatx())) 84 | other.append(np.array([[ 85 | 5678, # this will be suppressed 86 | 1234, 87 | ]], dtype=keras.backend.floatx())) 88 | other = [keras.backend.constant(o) for o in other] 89 | 90 | # compute output 91 | actual = filter_detections_layer.call([boxes, classification] + other) 92 | actual_boxes = keras.backend.eval(actual[0]) 93 | actual_scores = keras.backend.eval(actual[1]) 94 | actual_labels = keras.backend.eval(actual[2]) 95 | actual_other = [keras.backend.eval(a) for a in actual[3:]] 96 | 97 | # define 
expected output 98 | expected_boxes = -1 * np.ones((1, 300, 4), dtype=keras.backend.floatx()) 99 | expected_boxes[0, 0, :] = [0, 0, 10, 10] 100 | 101 | expected_scores = -1 * np.ones((1, 300), dtype=keras.backend.floatx()) 102 | expected_scores[0, 0] = 1 103 | 104 | expected_labels = -1 * np.ones((1, 300), dtype=keras.backend.floatx()) 105 | expected_labels[0, 0] = 1 106 | 107 | expected_other = [] 108 | expected_other.append(-1 * np.ones((1, 300, 2), dtype=keras.backend.floatx())) 109 | expected_other[-1][0, 0, :] = [0, 5678] 110 | expected_other.append(-1 * np.ones((1, 300), dtype=keras.backend.floatx())) 111 | expected_other[-1][0, 0] = 1234 112 | 113 | # assert actual and expected are equal 114 | np.testing.assert_array_equal(actual_boxes, expected_boxes) 115 | np.testing.assert_array_equal(actual_scores, expected_scores) 116 | np.testing.assert_array_equal(actual_labels, expected_labels) 117 | 118 | for a, e in zip(actual_other, expected_other): 119 | np.testing.assert_array_equal(a, e) 120 | 121 | def test_mini_batch(self): 122 | # create simple FilterDetections layer 123 | filter_detections_layer = keras_retinanet.layers.FilterDetections() 124 | 125 | # create input with batch_size=2 126 | boxes = np.array([ 127 | [ 128 | [0, 0, 10, 10], # this will be suppressed 129 | [0, 0, 10, 10], 130 | ], 131 | [ 132 | [100, 100, 150, 150], 133 | [100, 100, 150, 150], # this will be suppressed 134 | ], 135 | ], dtype=keras.backend.floatx()) 136 | boxes = keras.backend.constant(boxes) 137 | 138 | classification = np.array([ 139 | [ 140 | [0, 0.9], # this will be suppressed 141 | [0, 1], 142 | ], 143 | [ 144 | [1, 0], 145 | [0.9, 0], # this will be suppressed 146 | ], 147 | ], dtype=keras.backend.floatx()) 148 | classification = keras.backend.constant(classification) 149 | 150 | # compute output 151 | actual_boxes, actual_scores, actual_labels = filter_detections_layer.call([boxes, classification]) 152 | actual_boxes = keras.backend.eval(actual_boxes) 153 | actual_scores = keras.backend.eval(actual_scores) 154 | actual_labels = keras.backend.eval(actual_labels) 155 | 156 | # define expected output 157 | expected_boxes = -1 * np.ones((2, 300, 4), dtype=keras.backend.floatx()) 158 | expected_boxes[0, 0, :] = [0, 0, 10, 10] 159 | expected_boxes[1, 0, :] = [100, 100, 150, 150] 160 | 161 | expected_scores = -1 * np.ones((2, 300), dtype=keras.backend.floatx()) 162 | expected_scores[0, 0] = 1 163 | expected_scores[1, 0] = 1 164 | 165 | expected_labels = -1 * np.ones((2, 300), dtype=keras.backend.floatx()) 166 | expected_labels[0, 0] = 1 167 | expected_labels[1, 0] = 0 168 | 169 | # assert actual and expected are equal 170 | np.testing.assert_array_equal(actual_boxes, expected_boxes) 171 | np.testing.assert_array_equal(actual_scores, expected_scores) 172 | np.testing.assert_array_equal(actual_labels, expected_labels) 173 | -------------------------------------------------------------------------------- /keras_retinanet/preprocessing/pascal_voc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | from ..preprocessing.generator import Generator 18 | from ..utils.image import read_image_bgr 19 | 20 | import os 21 | import numpy as np 22 | from six import raise_from 23 | from PIL import Image 24 | 25 | try: 26 | import xml.etree.cElementTree as ET 27 | except ImportError: 28 | import xml.etree.ElementTree as ET 29 | 30 | voc_classes = { 31 | 'aeroplane' : 0, 32 | 'bicycle' : 1, 33 | 'bird' : 2, 34 | 'boat' : 3, 35 | 'bottle' : 4, 36 | 'bus' : 5, 37 | 'car' : 6, 38 | 'cat' : 7, 39 | 'chair' : 8, 40 | 'cow' : 9, 41 | 'diningtable' : 10, 42 | 'dog' : 11, 43 | 'horse' : 12, 44 | 'motorbike' : 13, 45 | 'person' : 14, 46 | 'pottedplant' : 15, 47 | 'sheep' : 16, 48 | 'sofa' : 17, 49 | 'train' : 18, 50 | 'tvmonitor' : 19 51 | } 52 | 53 | 54 | def _findNode(parent, name, debug_name=None, parse=None): 55 | if debug_name is None: 56 | debug_name = name 57 | 58 | result = parent.find(name) 59 | if result is None: 60 | raise ValueError('missing element \'{}\''.format(debug_name)) 61 | if parse is not None: 62 | try: 63 | return parse(result.text) 64 | except ValueError as e: 65 | raise_from(ValueError('illegal value for \'{}\': {}'.format(debug_name, e)), None) 66 | return result 67 | 68 | 69 | class PascalVocGenerator(Generator): 70 | """ Generate data for a Pascal VOC dataset. 71 | 72 | See http://host.robots.ox.ac.uk/pascal/VOC/ for more information. 73 | """ 74 | 75 | def __init__( 76 | self, 77 | data_dir, 78 | set_name, 79 | classes=voc_classes, 80 | image_extension='.jpg', 81 | skip_truncated=False, 82 | skip_difficult=False, 83 | **kwargs 84 | ): 85 | """ Initialize a Pascal VOC data generator. 86 | 87 | Args 88 | data_dir: Root directory of the VOC dataset (containing 'ImageSets', 'JPEGImages' and 'Annotations'). 89 | set_name: Name of the image set to parse (e.g. 'trainval' or 'test'). 90 | """ 91 | self.data_dir = data_dir 92 | self.set_name = set_name 93 | self.classes = classes 94 | self.image_names = [l.strip().split(None, 1)[0] for l in open(os.path.join(data_dir, 'ImageSets', 'Main', set_name + '.txt')).readlines()] 95 | self.image_extension = image_extension 96 | self.skip_truncated = skip_truncated 97 | self.skip_difficult = skip_difficult 98 | 99 | self.labels = {} 100 | for key, value in self.classes.items(): 101 | self.labels[value] = key 102 | 103 | super(PascalVocGenerator, self).__init__(**kwargs) 104 | 105 | def size(self): 106 | """ Size of the dataset. 107 | """ 108 | return len(self.image_names) 109 | 110 | def num_classes(self): 111 | """ Number of classes in the dataset. 112 | """ 113 | return len(self.classes) 114 | 115 | def has_label(self, label): 116 | """ Return True if label is a known label. 117 | """ 118 | return label in self.labels 119 | 120 | def has_name(self, name): 121 | """ Returns True if name is a known class. 122 | """ 123 | return name in self.classes 124 | 125 | def name_to_label(self, name): 126 | """ Map name to label. 127 | """ 128 | return self.classes[name] 129 | 130 | def label_to_name(self, label): 131 | """ Map label to name.
132 | """ 133 | return self.labels[label] 134 | 135 | def image_aspect_ratio(self, image_index): 136 | """ Compute the aspect ratio for an image with image_index. 137 | """ 138 | path = os.path.join(self.data_dir, 'JPEGImages', self.image_names[image_index] + self.image_extension) 139 | image = Image.open(path) 140 | return float(image.width) / float(image.height) 141 | 142 | def load_image(self, image_index): 143 | """ Load an image at the image_index. 144 | """ 145 | path = os.path.join(self.data_dir, 'JPEGImages', self.image_names[image_index] + self.image_extension) 146 | return read_image_bgr(path) 147 | 148 | def __parse_annotation(self, element): 149 | """ Parse an annotation given an XML element. 150 | """ 151 | truncated = _findNode(element, 'truncated', parse=int) 152 | difficult = _findNode(element, 'difficult', parse=int) 153 | 154 | class_name = _findNode(element, 'name').text 155 | if class_name not in self.classes: 156 | raise ValueError('class name \'{}\' not found in classes: {}'.format(class_name, list(self.classes.keys()))) 157 | 158 | box = np.zeros((4,)) 159 | label = self.name_to_label(class_name) 160 | 161 | bndbox = _findNode(element, 'bndbox') 162 | box[0] = _findNode(bndbox, 'xmin', 'bndbox.xmin', parse=float) - 1 163 | box[1] = _findNode(bndbox, 'ymin', 'bndbox.ymin', parse=float) - 1 164 | box[2] = _findNode(bndbox, 'xmax', 'bndbox.xmax', parse=float) - 1 165 | box[3] = _findNode(bndbox, 'ymax', 'bndbox.ymax', parse=float) - 1 166 | 167 | return truncated, difficult, box, label 168 | 169 | def __parse_annotations(self, xml_root): 170 | """ Parse all annotations under the xml_root. 171 | """ 172 | annotations = {'labels': np.empty((len(xml_root.findall('object')),)), 'bboxes': np.empty((len(xml_root.findall('object')), 4))} 173 | for i, element in enumerate(xml_root.iter('object')): 174 | try: 175 | truncated, difficult, box, label = self.__parse_annotation(element) 176 | except ValueError as e: 177 | raise_from(ValueError('could not parse object #{}: {}'.format(i, e)), None) 178 | 179 | if truncated and self.skip_truncated: 180 | continue 181 | if difficult and self.skip_difficult: 182 | continue 183 | 184 | annotations['bboxes'][i, :] = box 185 | annotations['labels'][i] = label 186 | 187 | return annotations 188 | 189 | def load_annotations(self, image_index): 190 | """ Load annotations for an image_index. 191 | """ 192 | filename = self.image_names[image_index] + '.xml' 193 | try: 194 | tree = ET.parse(os.path.join(self.data_dir, 'Annotations', filename)) 195 | return self.__parse_annotations(tree.getroot()) 196 | except ET.ParseError as e: 197 | raise_from(ValueError('invalid annotations file: {}: {}'.format(filename, e)), None) 198 | except ValueError as e: 199 | raise_from(ValueError('invalid annotations file: {}: {}'.format(filename, e)), None) 200 | -------------------------------------------------------------------------------- /keras_retinanet/utils/image.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | from __future__ import division 18 | import numpy as np 19 | import cv2 20 | from PIL import Image 21 | 22 | from .transform import change_transform_origin 23 | 24 | 25 | def read_image_bgr(path): 26 | """ Read an image in BGR format. 27 | 28 | Args 29 | path: Path to the image. 30 | """ 31 | image = np.asarray(Image.open(path).convert('RGB')) 32 | return image[:, :, ::-1].copy() 33 | 34 | 35 | def preprocess_image(x, mode='caffe'): 36 | """ Preprocess an image by subtracting the ImageNet mean. 37 | 38 | Args 39 | x: np.array of shape (None, None, 3) or (3, None, None). 40 | mode: One of "caffe" or "tf". 41 | - caffe: will zero-center each color channel with 42 | respect to the ImageNet dataset, without scaling. 43 | - tf: will scale pixels between -1 and 1, sample-wise. 44 | 45 | Returns 46 | The input with the ImageNet mean subtracted. 47 | """ 48 | # mostly identical to "https://github.com/keras-team/keras-applications/blob/master/keras_applications/imagenet_utils.py" 49 | # except for converting RGB -> BGR since we assume BGR already 50 | 51 | # always convert to float32 to keep compatibility with opencv 52 | x = x.astype(np.float32) 53 | 54 | if mode == 'tf': 55 | x /= 127.5 56 | x -= 1. 57 | elif mode == 'caffe': 58 | x[..., 0] -= 103.939 59 | x[..., 1] -= 116.779 60 | x[..., 2] -= 123.68 61 | 62 | return x 63 | 64 | 65 | def adjust_transform_for_image(transform, image, relative_translation): 66 | """ Adjust a transformation for a specific image. 67 | 68 | The translation of the matrix will be scaled with the size of the image. 69 | The linear part of the transformation will be adjusted so that the origin of the transformation will be at the center of the image. 70 | """ 71 | height, width, channels = image.shape 72 | 73 | result = transform 74 | 75 | # Scale the translation with the image size if specified. 76 | if relative_translation: 77 | result[0:2, 2] *= [width, height] 78 | 79 | # Move the origin of transformation. 80 | result = change_transform_origin(transform, (0.5 * width, 0.5 * height)) 81 | 82 | return result 83 | 84 | 85 | class TransformParameters: 86 | """ Struct holding parameters determining how to apply a transformation to an image. 87 | 88 | Args 89 | fill_mode: One of: 'constant', 'nearest', 'reflect', 'wrap' 90 | interpolation: One of: 'nearest', 'linear', 'cubic', 'area', 'lanczos4' 91 | cval: Fill value to use with fill_mode='constant' 92 | relative_translation: If true (the default), interpret translation as a factor of the image size. 93 | If false, interpret it as absolute pixels.
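A minimal usage sketch (illustrative; matrix and image are assumed to already exist in the caller, and apply_transform is defined later in this module): params = TransformParameters(fill_mode='constant', cval=0); warped = apply_transform(matrix, image, params)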
94 | """ 95 | def __init__( 96 | self, 97 | fill_mode = 'nearest', 98 | interpolation = 'linear', 99 | cval = 0, 100 | relative_translation = True, 101 | ): 102 | self.fill_mode = fill_mode 103 | self.cval = cval 104 | self.interpolation = interpolation 105 | self.relative_translation = relative_translation 106 | 107 | def cvBorderMode(self): 108 | if self.fill_mode == 'constant': 109 | return cv2.BORDER_CONSTANT 110 | if self.fill_mode == 'nearest': 111 | return cv2.BORDER_REPLICATE 112 | if self.fill_mode == 'reflect': 113 | return cv2.BORDER_REFLECT_101 114 | if self.fill_mode == 'wrap': 115 | return cv2.BORDER_WRAP 116 | 117 | def cvInterpolation(self): 118 | if self.interpolation == 'nearest': 119 | return cv2.INTER_NEAREST 120 | if self.interpolation == 'linear': 121 | return cv2.INTER_LINEAR 122 | if self.interpolation == 'cubic': 123 | return cv2.INTER_CUBIC 124 | if self.interpolation == 'area': 125 | return cv2.INTER_AREA 126 | if self.interpolation == 'lanczos4': 127 | return cv2.INTER_LANCZOS4 128 | 129 | 130 | def apply_transform(matrix, image, params): 131 | """ 132 | Apply a transformation to an image. 133 | 134 | The origin of transformation is at the top left corner of the image. 135 | 136 | The matrix is interpreted such that a point (x, y) on the original image is moved to transform * (x, y) in the generated image. 137 | Mathematically speaking, that means that the matrix is a transformation from the transformed image space to the original image space. 138 | 139 | Args 140 | matrix: A homogeneous 3 by 3 matrix holding representing the transformation to apply. 141 | image: The image to transform. 142 | params: The transform parameters (see TransformParameters) 143 | """ 144 | output = cv2.warpAffine( 145 | image, 146 | matrix[:2, :], 147 | dsize = (image.shape[1], image.shape[0]), 148 | flags = params.cvInterpolation(), 149 | borderMode = params.cvBorderMode(), 150 | borderValue = params.cval, 151 | ) 152 | return output 153 | 154 | 155 | def compute_resize_scale(image_shape, min_side=800, max_side=1333): 156 | """ Compute an image scale such that the image size is constrained to min_side and max_side. 157 | 158 | Args 159 | min_side: The image's min side will be equal to min_side after resizing. 160 | max_side: If after resizing the image's max side is above max_side, resize until the max side is equal to max_side. 161 | 162 | Returns 163 | A resizing scale. 164 | """ 165 | (rows, cols, _) = image_shape 166 | 167 | smallest_side = min(rows, cols) 168 | 169 | # rescale the image so the smallest side is min_side 170 | scale = min_side / smallest_side 171 | 172 | # check if the largest side is now greater than max_side, which can happen 173 | # when images have a large aspect ratio 174 | largest_side = max(rows, cols) 175 | if largest_side * scale > max_side: 176 | scale = max_side / largest_side 177 | 178 | return scale 179 | 180 | 181 | def resize_image(img, min_side=800, max_side=1333): 182 | """ Resize an image such that the size is constrained to min_side and max_side. 183 | 184 | Args 185 | min_side: The image's min side will be equal to min_side after resizing. 186 | max_side: If after resizing the image's max side is above max_side, resize until the max side is equal to max_side. 187 | 188 | Returns 189 | A resized image. 
190 | """ 191 | # compute scale to resize the image 192 | scale = compute_resize_scale(img.shape, min_side=min_side, max_side=max_side) 193 | 194 | # resize the image with the computed scale 195 | img = cv2.resize(img, None, fx=scale, fy=scale) 196 | 197 | return img, scale 198 | -------------------------------------------------------------------------------- /keras_retinanet/layers/_misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import keras 18 | from .. import backend 19 | from ..utils import anchors as utils_anchors 20 | 21 | import numpy as np 22 | 23 | 24 | class Anchors(keras.layers.Layer): 25 | """ Keras layer for generating achors for a given shape. 26 | """ 27 | 28 | def __init__(self, size, stride, ratios=None, scales=None, *args, **kwargs): 29 | """ Initializer for an Anchors layer. 30 | 31 | Args 32 | size: The base size of the anchors to generate. 33 | stride: The stride of the anchors to generate. 34 | ratios: The ratios of the anchors to generate (defaults to AnchorParameters.default.ratios). 35 | scales: The scales of the anchors to generate (defaults to AnchorParameters.default.scales). 
36 | """ 37 | self.size = size 38 | self.stride = stride 39 | self.ratios = ratios 40 | self.scales = scales 41 | 42 | if ratios is None: 43 | self.ratios = utils_anchors.AnchorParameters.default.ratios 44 | elif isinstance(ratios, list): 45 | self.ratios = np.array(ratios) 46 | if scales is None: 47 | self.scales = utils_anchors.AnchorParameters.default.scales 48 | elif isinstance(scales, list): 49 | self.scales = np.array(scales) 50 | 51 | self.num_anchors = len(ratios) * len(scales) 52 | self.anchors = keras.backend.variable(utils_anchors.generate_anchors( 53 | base_size=size, 54 | ratios=ratios, 55 | scales=scales, 56 | )) 57 | 58 | super(Anchors, self).__init__(*args, **kwargs) 59 | 60 | def call(self, inputs, **kwargs): 61 | features = inputs 62 | features_shape = keras.backend.shape(features) 63 | 64 | # generate proposals from bbox deltas and shifted anchors 65 | if keras.backend.image_data_format() == 'channels_first': 66 | anchors = backend.shift(features_shape[2:4], self.stride, self.anchors) 67 | else: 68 | anchors = backend.shift(features_shape[1:3], self.stride, self.anchors) 69 | anchors = keras.backend.tile(keras.backend.expand_dims(anchors, axis=0), (features_shape[0], 1, 1)) 70 | 71 | return anchors 72 | 73 | def compute_output_shape(self, input_shape): 74 | if None not in input_shape[1:]: 75 | if keras.backend.image_data_format() == 'channels_first': 76 | total = np.prod(input_shape[2:4]) * self.num_anchors 77 | else: 78 | total = np.prod(input_shape[1:3]) * self.num_anchors 79 | 80 | return (input_shape[0], total, 4) 81 | else: 82 | return (input_shape[0], None, 4) 83 | 84 | def get_config(self): 85 | config = super(Anchors, self).get_config() 86 | config.update({ 87 | 'size' : self.size, 88 | 'stride' : self.stride, 89 | 'ratios' : self.ratios.tolist(), 90 | 'scales' : self.scales.tolist(), 91 | }) 92 | 93 | return config 94 | 95 | 96 | class UpsampleLike(keras.layers.Layer): 97 | """ Keras layer for upsampling a Tensor to be the same shape as another Tensor. 98 | """ 99 | 100 | def call(self, inputs, **kwargs): 101 | source, target = inputs 102 | target_shape = keras.backend.shape(target) 103 | if keras.backend.image_data_format() == 'channels_first': 104 | source = backend.transpose(source, (0, 2, 3, 1)) 105 | output = backend.resize_images(source, (target_shape[2], target_shape[3]), method='nearest') 106 | output = backend.transpose(output, (0, 3, 1, 2)) 107 | return output 108 | else: 109 | return backend.resize_images(source, (target_shape[1], target_shape[2]), method='nearest') 110 | 111 | def compute_output_shape(self, input_shape): 112 | if keras.backend.image_data_format() == 'channels_first': 113 | return (input_shape[0][0], input_shape[0][1]) + input_shape[1][2:4] 114 | else: 115 | return (input_shape[0][0],) + input_shape[1][1:3] + (input_shape[0][-1],) 116 | 117 | 118 | class RegressBoxes(keras.layers.Layer): 119 | """ Keras layer for applying regression values to boxes. 120 | """ 121 | 122 | def __init__(self, mean=None, std=None, *args, **kwargs): 123 | """ Initializer for the RegressBoxes layer. 124 | 125 | Args 126 | mean: The mean value of the regression values which was used for normalization. 127 | std: The standard value of the regression values which was used for normalization. 
128 | """ 129 | if mean is None: 130 | mean = np.array([0, 0, 0, 0]) 131 | if std is None: 132 | std = np.array([0.2, 0.2, 0.2, 0.2]) 133 | 134 | if isinstance(mean, (list, tuple)): 135 | mean = np.array(mean) 136 | elif not isinstance(mean, np.ndarray): 137 | raise ValueError('Expected mean to be a np.ndarray, list or tuple. Received: {}'.format(type(mean))) 138 | 139 | if isinstance(std, (list, tuple)): 140 | std = np.array(std) 141 | elif not isinstance(std, np.ndarray): 142 | raise ValueError('Expected std to be a np.ndarray, list or tuple. Received: {}'.format(type(std))) 143 | 144 | self.mean = mean 145 | self.std = std 146 | super(RegressBoxes, self).__init__(*args, **kwargs) 147 | 148 | def call(self, inputs, **kwargs): 149 | anchors, regression = inputs 150 | return backend.bbox_transform_inv(anchors, regression, mean=self.mean, std=self.std) 151 | 152 | def compute_output_shape(self, input_shape): 153 | return input_shape[0] 154 | 155 | def get_config(self): 156 | config = super(RegressBoxes, self).get_config() 157 | config.update({ 158 | 'mean': self.mean.tolist(), 159 | 'std' : self.std.tolist(), 160 | }) 161 | 162 | return config 163 | 164 | 165 | class ClipBoxes(keras.layers.Layer): 166 | """ Keras layer to clip box values to lie inside a given shape. 167 | """ 168 | 169 | def call(self, inputs, **kwargs): 170 | image, boxes = inputs 171 | shape = keras.backend.cast(keras.backend.shape(image), keras.backend.floatx()) 172 | if keras.backend.image_data_format() == 'channels_first': 173 | height = shape[2] 174 | width = shape[3] 175 | else: 176 | height = shape[1] 177 | width = shape[2] 178 | x1 = backend.clip_by_value(boxes[:, :, 0], 0, width) 179 | y1 = backend.clip_by_value(boxes[:, :, 1], 0, height) 180 | x2 = backend.clip_by_value(boxes[:, :, 2], 0, width) 181 | y2 = backend.clip_by_value(boxes[:, :, 3], 0, height) 182 | 183 | return keras.backend.stack([x1, y1, x2, y2], axis=2) 184 | 185 | def compute_output_shape(self, input_shape): 186 | return input_shape[1] 187 | -------------------------------------------------------------------------------- /tests/layers/test_misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ 16 | 17 | import keras 18 | import keras_retinanet.layers 19 | 20 | import numpy as np 21 | 22 | 23 | class TestAnchors(object): 24 | def test_simple(self): 25 | # create simple Anchors layer 26 | anchors_layer = keras_retinanet.layers.Anchors( 27 | size=32, 28 | stride=8, 29 | ratios=np.array([1], keras.backend.floatx()), 30 | scales=np.array([1], keras.backend.floatx()), 31 | ) 32 | 33 | # create fake features input (only shape is used anyway) 34 | features = np.zeros((1, 2, 2, 1024), dtype=keras.backend.floatx()) 35 | features = keras.backend.variable(features) 36 | 37 | # call the Anchors layer 38 | anchors = anchors_layer.call(features) 39 | anchors = keras.backend.eval(anchors) 40 | 41 | # expected anchor values 42 | expected = np.array([[ 43 | [-12, -12, 20, 20], 44 | [-4 , -12, 28, 20], 45 | [-12, -4 , 20, 28], 46 | [-4 , -4 , 28, 28], 47 | ]], dtype=keras.backend.floatx()) 48 | 49 | # test anchor values 50 | np.testing.assert_array_equal(anchors, expected) 51 | 52 | # mark test to fail 53 | def test_mini_batch(self): 54 | # create simple Anchors layer 55 | anchors_layer = keras_retinanet.layers.Anchors( 56 | size=32, 57 | stride=8, 58 | ratios=np.array([1], dtype=keras.backend.floatx()), 59 | scales=np.array([1], dtype=keras.backend.floatx()), 60 | ) 61 | 62 | # create fake features input with batch_size=2 63 | features = np.zeros((2, 2, 2, 1024), dtype=keras.backend.floatx()) 64 | features = keras.backend.variable(features) 65 | 66 | # call the Anchors layer 67 | anchors = anchors_layer.call(features) 68 | anchors = keras.backend.eval(anchors) 69 | 70 | # expected anchor values 71 | expected = np.array([[ 72 | [-12, -12, 20, 20], 73 | [-4 , -12, 28, 20], 74 | [-12, -4 , 20, 28], 75 | [-4 , -4 , 28, 28], 76 | ]], dtype=keras.backend.floatx()) 77 | expected = np.tile(expected, (2, 1, 1)) 78 | 79 | # test anchor values 80 | np.testing.assert_array_equal(anchors, expected) 81 | 82 | 83 | class TestUpsampleLike(object): 84 | def test_simple(self): 85 | # create simple UpsampleLike layer 86 | upsample_like_layer = keras_retinanet.layers.UpsampleLike() 87 | 88 | # create input source 89 | source = np.zeros((1, 2, 2, 1), dtype=keras.backend.floatx()) 90 | source = keras.backend.variable(source) 91 | target = np.zeros((1, 5, 5, 1), dtype=keras.backend.floatx()) 92 | expected = target 93 | target = keras.backend.variable(target) 94 | 95 | # compute output 96 | actual = upsample_like_layer.call([source, target]) 97 | actual = keras.backend.eval(actual) 98 | 99 | np.testing.assert_array_equal(actual, expected) 100 | 101 | def test_mini_batch(self): 102 | # create simple UpsampleLike layer 103 | upsample_like_layer = keras_retinanet.layers.UpsampleLike() 104 | 105 | # create input source 106 | source = np.zeros((2, 2, 2, 1), dtype=keras.backend.floatx()) 107 | source = keras.backend.variable(source) 108 | 109 | target = np.zeros((2, 5, 5, 1), dtype=keras.backend.floatx()) 110 | expected = target 111 | target = keras.backend.variable(target) 112 | 113 | # compute output 114 | actual = upsample_like_layer.call([source, target]) 115 | actual = keras.backend.eval(actual) 116 | 117 | np.testing.assert_array_equal(actual, expected) 118 | 119 | 120 | class TestRegressBoxes(object): 121 | def test_simple(self): 122 | mean = [0, 0, 0, 0] 123 | std = [0.2, 0.2, 0.2, 0.2] 124 | 125 | # create simple RegressBoxes layer 126 | regress_boxes_layer = keras_retinanet.layers.RegressBoxes(mean=mean, std=std) 127 | 128 | # create input 129 | anchors = np.array([[ 130 | [0 , 0 , 10 , 10 ], 131 | [50, 50, 
100, 100], 132 | [20, 20, 40 , 40 ], 133 | ]], dtype=keras.backend.floatx()) 134 | anchors = keras.backend.variable(anchors) 135 | 136 | regression = np.array([[ 137 | [0 , 0 , 0 , 0 ], 138 | [0.1, 0.1, 0 , 0 ], 139 | [0 , 0 , 0.1, 0.1], 140 | ]], dtype=keras.backend.floatx()) 141 | regression = keras.backend.variable(regression) 142 | 143 | # compute output 144 | actual = regress_boxes_layer.call([anchors, regression]) 145 | actual = keras.backend.eval(actual) 146 | 147 | # compute expected output 148 | expected = np.array([[ 149 | [0 , 0 , 10 , 10 ], 150 | [51, 51, 100 , 100 ], 151 | [20, 20, 40.4, 40.4], 152 | ]], dtype=keras.backend.floatx()) 153 | 154 | np.testing.assert_array_almost_equal(actual, expected, decimal=2) 155 | 156 | # mark test to fail 157 | def test_mini_batch(self): 158 | mean = [0, 0, 0, 0] 159 | std = [0.2, 0.2, 0.2, 0.2] 160 | 161 | # create simple RegressBoxes layer 162 | regress_boxes_layer = keras_retinanet.layers.RegressBoxes(mean=mean, std=std) 163 | 164 | # create input 165 | anchors = np.array([ 166 | [ 167 | [0 , 0 , 10 , 10 ], # 1 168 | [50, 50, 100, 100], # 2 169 | [20, 20, 40 , 40 ], # 3 170 | ], 171 | [ 172 | [20, 20, 40 , 40 ], # 3 173 | [0 , 0 , 10 , 10 ], # 1 174 | [50, 50, 100, 100], # 2 175 | ], 176 | ], dtype=keras.backend.floatx()) 177 | anchors = keras.backend.variable(anchors) 178 | 179 | regression = np.array([ 180 | [ 181 | [0 , 0 , 0 , 0 ], # 1 182 | [0.1, 0.1, 0 , 0 ], # 2 183 | [0 , 0 , 0.1, 0.1], # 3 184 | ], 185 | [ 186 | [0 , 0 , 0.1, 0.1], # 3 187 | [0 , 0 , 0 , 0 ], # 1 188 | [0.1, 0.1, 0 , 0 ], # 2 189 | ], 190 | ], dtype=keras.backend.floatx()) 191 | regression = keras.backend.variable(regression) 192 | 193 | # compute output 194 | actual = regress_boxes_layer.call([anchors, regression]) 195 | actual = keras.backend.eval(actual) 196 | 197 | # compute expected output 198 | expected = np.array([ 199 | [ 200 | [0 , 0 , 10 , 10 ], # 1 201 | [51, 51, 100 , 100 ], # 2 202 | [20, 20, 40.4, 40.4], # 3 203 | ], 204 | [ 205 | [20, 20, 40.4, 40.4], # 3 206 | [0 , 0 , 10 , 10 ], # 1 207 | [51, 51, 100 , 100 ], # 2 208 | ], 209 | ], dtype=keras.backend.floatx()) 210 | 211 | np.testing.assert_array_almost_equal(actual, expected, decimal=2) 212 | -------------------------------------------------------------------------------- /keras_retinanet/preprocessing/csv_generator.py: -------------------------------------------------------------------------------- 1 | from .generator import Generator 2 | from ..utils.image import read_image_bgr 3 | 4 | import numpy as np 5 | from PIL import Image 6 | from six import raise_from 7 | 8 | import csv 9 | import sys 10 | import os.path 11 | 12 | 13 | def _parse(value, function, fmt): 14 | """ 15 | Parse a string into a value, and format a nice ValueError if it fails. 16 | 17 | Returns `function(value)`. 18 | Any `ValueError` raised is caught and a new `ValueError` is raised 19 | with message `fmt.format(e)`, where `e` is the caught `ValueError`. 20 | """ 21 | try: 22 | return function(value) 23 | except ValueError as e: 24 | raise_from(ValueError(fmt.format(e)), None) 25 | 26 | 27 | def _read_classes(csv_reader): 28 | """ Parse the classes file given by csv_reader.
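The expected format is one 'class_name,class_id' pair per line; for example (hypothetical names), the two lines 'car,0' and 'pedestrian,1' parse to {'car': 0, 'pedestrian': 1}.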
29 | """ 30 | result = {} 31 | for line, row in enumerate(csv_reader): 32 | line += 1 33 | 34 | try: 35 | class_name, class_id = row 36 | except ValueError: 37 | raise_from(ValueError('line {}: format should be \'class_name,class_id\''.format(line)), None) 38 | class_id = _parse(class_id, int, 'line {}: malformed class ID: {{}}'.format(line)) 39 | 40 | if class_name in result: 41 | raise ValueError('line {}: duplicate class name: \'{}\''.format(line, class_name)) 42 | result[class_name] = class_id 43 | return result 44 | 45 | 46 | def _read_annotations(csv_reader, classes): 47 | """ Read annotations from the csv_reader. 48 | """ 49 | result = {} 50 | for line, row in enumerate(csv_reader): 51 | line += 1 52 | 53 | try: 54 | img_file, x1, y1, x2, y2, class_name = row[:6] 55 | except ValueError: 56 | raise_from(ValueError('line {}: format should be \'img_file,x1,y1,x2,y2,class_name\' or \'img_file,,,,,\''.format(line)), None) 57 | 58 | if img_file not in result: 59 | result[img_file] = [] 60 | 61 | # If a row contains only an image path, it's an image without annotations. 62 | if (x1, y1, x2, y2, class_name) == ('', '', '', '', ''): 63 | continue 64 | 65 | x1 = _parse(x1, int, 'line {}: malformed x1: {{}}'.format(line)) 66 | y1 = _parse(y1, int, 'line {}: malformed y1: {{}}'.format(line)) 67 | x2 = _parse(x2, int, 'line {}: malformed x2: {{}}'.format(line)) 68 | y2 = _parse(y2, int, 'line {}: malformed y2: {{}}'.format(line)) 69 | 70 | # Check that the bounding box is valid. 71 | if x2 <= x1: 72 | raise ValueError('line {}: x2 ({}) must be higher than x1 ({})'.format(line, x2, x1)) 73 | if y2 <= y1: 74 | raise ValueError('line {}: y2 ({}) must be higher than y1 ({})'.format(line, y2, y1)) 75 | 76 | # check if the current class name is correctly present 77 | if class_name not in classes: 78 | raise ValueError('line {}: unknown class name: \'{}\' (classes: {})'.format(line, class_name, classes)) 79 | 80 | result[img_file].append({'x1': x1, 'x2': x2, 'y1': y1, 'y2': y2, 'class': class_name}) 81 | return result 82 | 83 | 84 | def _open_for_csv(path): 85 | """ Open a file with flags suitable for csv.reader. 86 | 87 | This is different for python2 it means with mode 'rb', 88 | for python3 this means 'r' with "universal newlines". 89 | """ 90 | if sys.version_info[0] < 3: 91 | return open(path, 'rb') 92 | else: 93 | return open(path, 'r', newline='') 94 | 95 | 96 | class CSVGenerator(Generator): 97 | """ Generate data for a custom CSV dataset. 98 | 99 | See https://github.com/fizyr/keras-retinanet#csv-datasets for more information. 100 | """ 101 | 102 | def __init__( 103 | self, 104 | csv_data_file, 105 | csv_class_file, 106 | base_dir=None, 107 | **kwargs 108 | ): 109 | """ Initialize a CSV data generator. 110 | 111 | Args 112 | csv_data_file: Path to the CSV annotations file. 113 | csv_class_file: Path to the CSV classes file. 114 | base_dir: Directory w.r.t. where the files are to be searched (defaults to the directory containing the csv_data_file). 115 | """ 116 | self.image_names = [] 117 | self.image_data = {} 118 | self.base_dir = base_dir 119 | 120 | # Take base_dir from annotations file if not explicitly specified. 
121 | if self.base_dir is None: 122 | self.base_dir = os.path.dirname(csv_data_file) 123 | 124 | # parse the provided class file 125 | try: 126 | with _open_for_csv(csv_class_file) as file: 127 | self.classes = _read_classes(csv.reader(file, delimiter=',')) 128 | except ValueError as e: 129 | raise_from(ValueError('invalid CSV class file: {}: {}'.format(csv_class_file, e)), None) 130 | 131 | self.labels = {} 132 | for key, value in self.classes.items(): 133 | self.labels[value] = key 134 | 135 | # csv with img_path, x1, y1, x2, y2, class_name 136 | try: 137 | with _open_for_csv(csv_data_file) as file: 138 | self.image_data = _read_annotations(csv.reader(file, delimiter=','), self.classes) 139 | except ValueError as e: 140 | raise_from(ValueError('invalid CSV annotations file: {}: {}'.format(csv_data_file, e)), None) 141 | self.image_names = list(self.image_data.keys()) 142 | 143 | super(CSVGenerator, self).__init__(**kwargs) 144 | 145 | def size(self): 146 | """ Size of the dataset. 147 | """ 148 | return len(self.image_names) 149 | 150 | def num_classes(self): 151 | """ Number of classes in the dataset. 152 | """ 153 | return max(self.classes.values()) + 1 154 | 155 | def has_label(self, label): 156 | """ Return True if label is a known label. 157 | """ 158 | return label in self.labels 159 | 160 | def has_name(self, name): 161 | """ Returns True if name is a known class. 162 | """ 163 | return name in self.classes 164 | 165 | def name_to_label(self, name): 166 | """ Map name to label. 167 | """ 168 | return self.classes[name] 169 | 170 | def label_to_name(self, label): 171 | """ Map label to name. 172 | """ 173 | return self.labels[label] 174 | 175 | def image_path(self, image_index): 176 | """ Returns the image path for image_index. 177 | """ 178 | return os.path.join(self.base_dir, self.image_names[image_index]) 179 | 180 | def image_aspect_ratio(self, image_index): 181 | """ Compute the aspect ratio for an image with image_index. 182 | """ 183 | # PIL is fast for metadata 184 | image = Image.open(self.image_path(image_index)) 185 | return float(image.width) / float(image.height) 186 | 187 | def load_image(self, image_index): 188 | """ Load an image at the image_index. 189 | """ 190 | return read_image_bgr(self.image_path(image_index)) 191 | 192 | def load_annotations(self, image_index): 193 | """ Load annotations for an image_index. 194 | """ 195 | path = self.image_names[image_index] 196 | annotations = {'labels': np.empty((0,)), 'bboxes': np.empty((0, 4))} 197 | 198 | for idx, annot in enumerate(self.image_data[path]): 199 | annotations['labels'] = np.concatenate((annotations['labels'], [self.name_to_label(annot['class'])])) 200 | annotations['bboxes'] = np.concatenate((annotations['bboxes'], [[ 201 | float(annot['x1']), 202 | float(annot['y1']), 203 | float(annot['x2']), 204 | float(annot['y2']), 205 | ]])) 206 | 207 | return annotations 208 | -------------------------------------------------------------------------------- /tests/preprocessing/test_csv_generator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import csv 18 | import pytest 19 | try: 20 | from io import StringIO 21 | except ImportError: 22 | from StringIO import StringIO 23 | 24 | from keras_retinanet.preprocessing import csv_generator 25 | 26 | 27 | def csv_str(string): 28 | if str == bytes: 29 | string = string.decode('utf-8') 30 | return csv.reader(StringIO(string)) 31 | 32 | 33 | def annotation(x1, y1, x2, y2, class_name): 34 | return {'x1': x1, 'y1': y1, 'x2': x2, 'y2': y2, 'class': class_name} 35 | 36 | 37 | def test_read_classes(): 38 | assert csv_generator._read_classes(csv_str('')) == {} 39 | assert csv_generator._read_classes(csv_str('a,1')) == {'a': 1} 40 | assert csv_generator._read_classes(csv_str('a,1\nb,2')) == {'a': 1, 'b': 2} 41 | 42 | 43 | def test_read_classes_wrong_format(): 44 | with pytest.raises(ValueError): 45 | try: 46 | csv_generator._read_classes(csv_str('a,b,c')) 47 | except ValueError as e: 48 | assert str(e).startswith('line 1: format should be') 49 | raise 50 | with pytest.raises(ValueError): 51 | try: 52 | csv_generator._read_classes(csv_str('a,1\nb,c,d')) 53 | except ValueError as e: 54 | assert str(e).startswith('line 2: format should be') 55 | raise 56 | 57 | 58 | def test_read_classes_malformed_class_id(): 59 | with pytest.raises(ValueError): 60 | try: 61 | csv_generator._read_classes(csv_str('a,b')) 62 | except ValueError as e: 63 | assert str(e).startswith("line 1: malformed class ID:") 64 | raise 65 | 66 | with pytest.raises(ValueError): 67 | try: 68 | csv_generator._read_classes(csv_str('a,1\nb,c')) 69 | except ValueError as e: 70 | assert str(e).startswith('line 2: malformed class ID:') 71 | raise 72 | 73 | 74 | def test_read_classes_duplicate_name(): 75 | with pytest.raises(ValueError): 76 | try: 77 | csv_generator._read_classes(csv_str('a,1\nb,2\na,3')) 78 | except ValueError as e: 79 | assert str(e).startswith('line 3: duplicate class name') 80 | raise 81 | 82 | 83 | def test_read_annotations(): 84 | classes = {'a': 1, 'b': 2, 'c': 4, 'd': 10} 85 | annotations = csv_generator._read_annotations(csv_str( 86 | 'a.png,0,1,2,3,a' '\n' 87 | 'b.png,4,5,6,7,b' '\n' 88 | 'c.png,8,9,10,11,c' '\n' 89 | 'd.png,12,13,14,15,d' '\n' 90 | ), classes) 91 | assert annotations == { 92 | 'a.png': [annotation( 0, 1, 2, 3, 'a')], 93 | 'b.png': [annotation( 4, 5, 6, 7, 'b')], 94 | 'c.png': [annotation( 8, 9, 10, 11, 'c')], 95 | 'd.png': [annotation(12, 13, 14, 15, 'd')], 96 | } 97 | 98 | 99 | def test_read_annotations_multiple(): 100 | classes = {'a': 1, 'b': 2, 'c': 4, 'd': 10} 101 | annotations = csv_generator._read_annotations(csv_str( 102 | 'a.png,0,1,2,3,a' '\n' 103 | 'b.png,4,5,6,7,b' '\n' 104 | 'a.png,8,9,10,11,c' '\n' 105 | ), classes) 106 | assert annotations == { 107 | 'a.png': [ 108 | annotation(0, 1, 2, 3, 'a'), 109 | annotation(8, 9, 10, 11, 'c'), 110 | ], 111 | 'b.png': [annotation(4, 5, 6, 7, 'b')], 112 | } 113 | 114 | 115 | def test_read_annotations_wrong_format(): 116 | classes = {'a': 1, 'b': 2, 'c': 4, 'd': 10} 117 | with pytest.raises(ValueError): 118 | try: 119 | csv_generator._read_annotations(csv_str('a.png,1,2,3,a'), classes) 120 | except
ValueError as e: 121 | assert str(e).startswith("line 1: format should be") 122 | raise 123 | 124 | with pytest.raises(ValueError): 125 | try: 126 | csv_generator._read_annotations(csv_str( 127 | 'a.png,0,1,2,3,a' '\n' 128 | 'a.png,1,2,3,a' '\n' 129 | ), classes) 130 | except ValueError as e: 131 | assert str(e).startswith("line 2: format should be") 132 | raise 133 | 134 | 135 | def test_read_annotations_wrong_x1(): 136 | with pytest.raises(ValueError): 137 | try: 138 | csv_generator._read_annotations(csv_str('a.png,a,0,1,2,a'), {'a': 1}) 139 | except ValueError as e: 140 | assert str(e).startswith("line 1: malformed x1:") 141 | raise 142 | 143 | 144 | def test_read_annotations_wrong_y1(): 145 | with pytest.raises(ValueError): 146 | try: 147 | csv_generator._read_annotations(csv_str('a.png,0,a,1,2,a'), {'a': 1}) 148 | except ValueError as e: 149 | assert str(e).startswith("line 1: malformed y1:") 150 | raise 151 | 152 | 153 | def test_read_annotations_wrong_x2(): 154 | with pytest.raises(ValueError): 155 | try: 156 | csv_generator._read_annotations(csv_str('a.png,0,1,a,2,a'), {'a': 1}) 157 | except ValueError as e: 158 | assert str(e).startswith("line 1: malformed x2:") 159 | raise 160 | 161 | 162 | def test_read_annotations_wrong_y2(): 163 | with pytest.raises(ValueError): 164 | try: 165 | csv_generator._read_annotations(csv_str('a.png,0,1,2,a,a'), {'a': 1}) 166 | except ValueError as e: 167 | assert str(e).startswith("line 1: malformed y2:") 168 | raise 169 | 170 | 171 | def test_read_annotations_wrong_class(): 172 | with pytest.raises(ValueError): 173 | try: 174 | csv_generator._read_annotations(csv_str('a.png,0,1,2,3,g'), {'a': 1}) 175 | except ValueError as e: 176 | assert str(e).startswith("line 1: unknown class name:") 177 | raise 178 | 179 | 180 | def test_read_annotations_invalid_bb_x(): 181 | with pytest.raises(ValueError): 182 | try: 183 | csv_generator._read_annotations(csv_str('a.png,1,2,1,3,g'), {'a': 1}) 184 | except ValueError as e: 185 | assert str(e).startswith("line 1: x2 (1) must be higher than x1 (1)") 186 | raise 187 | with pytest.raises(ValueError): 188 | try: 189 | csv_generator._read_annotations(csv_str('a.png,9,2,5,3,g'), {'a': 1}) 190 | except ValueError as e: 191 | assert str(e).startswith("line 1: x2 (5) must be higher than x1 (9)") 192 | raise 193 | 194 | 195 | def test_read_annotations_invalid_bb_y(): 196 | with pytest.raises(ValueError): 197 | try: 198 | csv_generator._read_annotations(csv_str('a.png,1,2,3,2,a'), {'a': 1}) 199 | except ValueError as e: 200 | assert str(e).startswith("line 1: y2 (2) must be higher than y1 (2)") 201 | raise 202 | with pytest.raises(ValueError): 203 | try: 204 | csv_generator._read_annotations(csv_str('a.png,1,8,3,5,a'), {'a': 1}) 205 | except ValueError as e: 206 | assert str(e).startswith("line 1: y2 (5) must be higher than y1 (8)") 207 | raise 208 | 209 | 210 | def test_read_annotations_empty_image(): 211 | # Check that images without annotations are parsed. 212 | assert csv_generator._read_annotations(csv_str('a.png,,,,,\nb.png,,,,,'), {'a': 1}) == {'a.png': [], 'b.png': []} 213 | 214 | # Check that lines without annotations don't clear earlier annotations. 
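# e.g. (hypothetical file contents) a CSV containing 'a.png,0,1,2,3,a' followed by 'a.png,,,,,' must still yield the annotation parsed from the first row, as asserted below.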
215 | assert csv_generator._read_annotations(csv_str('a.png,0,1,2,3,a\na.png,,,,,'), {'a': 1}) == {'a.png': [annotation(0, 1, 2, 3, 'a')]} 216 | -------------------------------------------------------------------------------- /keras_retinanet/bin/evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Copyright 2017-2018 Fizyr (https://fizyr.com) 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import argparse 18 | import os 19 | import sys 20 | 21 | import keras 22 | import tensorflow as tf 23 | 24 | # Allow relative imports when being executed as script. 25 | if __name__ == "__main__" and __package__ is None: 26 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) 27 | import keras_retinanet.bin # noqa: F401 28 | __package__ = "keras_retinanet.bin" 29 | 30 | # Change these to absolute imports if you copy this script outside the keras_retinanet package. 31 | from .. import models 32 | from ..preprocessing.csv_generator import CSVGenerator 33 | from ..preprocessing.pascal_voc import PascalVocGenerator 34 | from ..utils.config import read_config_file, parse_anchor_parameters 35 | from ..utils.eval import evaluate 36 | from ..utils.keras_version import check_keras_version 37 | 38 | 39 | def get_session(): 40 | """ Construct a modified tf session. 41 | """ 42 | config = tf.ConfigProto() 43 | config.gpu_options.allow_growth = True 44 | return tf.Session(config=config) 45 | 46 | 47 | def create_generator(args): 48 | """ Create generators for evaluation. 49 | """ 50 | if args.dataset_type == 'coco': 51 | # import here to prevent unnecessary dependency on cocoapi 52 | from ..preprocessing.coco import CocoGenerator 53 | 54 | validation_generator = CocoGenerator( 55 | args.coco_path, 56 | 'val2017', 57 | image_min_side=args.image_min_side, 58 | image_max_side=args.image_max_side, 59 | config=args.config 60 | ) 61 | elif args.dataset_type == 'pascal': 62 | validation_generator = PascalVocGenerator( 63 | args.pascal_path, 64 | 'test', 65 | image_min_side=args.image_min_side, 66 | image_max_side=args.image_max_side, 67 | config=args.config 68 | ) 69 | elif args.dataset_type == 'csv': 70 | validation_generator = CSVGenerator( 71 | args.annotations, 72 | args.classes, 73 | image_min_side=args.image_min_side, 74 | image_max_side=args.image_max_side, 75 | config=args.config 76 | ) 77 | else: 78 | raise ValueError('Invalid data type received: {}'.format(args.dataset_type)) 79 | 80 | return validation_generator 81 | 82 | 83 | def parse_args(args): 84 | """ Parse the arguments. 85 | """ 86 | parser = argparse.ArgumentParser(description='Evaluation script for a RetinaNet network.') 87 | subparsers = parser.add_subparsers(help='Arguments for specific dataset types.', dest='dataset_type') 88 | subparsers.required = True 89 | 90 | coco_parser = subparsers.add_parser('coco') 91 | coco_parser.add_argument('coco_path', help='Path to dataset directory (e.g.
/tmp/COCO).') 92 | 93 | pascal_parser = subparsers.add_parser('pascal') 94 | pascal_parser.add_argument('pascal_path', help='Path to dataset directory (e.g. /tmp/VOCdevkit).') 95 | 96 | csv_parser = subparsers.add_parser('csv') 97 | csv_parser.add_argument('annotations', help='Path to CSV file containing annotations for evaluation.') 98 | csv_parser.add_argument('classes', help='Path to a CSV file containing class label mapping.') 99 | 100 | parser.add_argument('model', help='Path to RetinaNet model.') 101 | parser.add_argument('--convert-model', help='Convert the model to an inference model (i.e. the input is a training model).', action='store_true') 102 | parser.add_argument('--backbone', help='The backbone of the model.', default='resnet50') 103 | parser.add_argument('--gpu', help='Id of the GPU to use (as reported by nvidia-smi).') 104 | parser.add_argument('--score-threshold', help='Threshold on score to filter detections with (defaults to 0.05).', default=0.05, type=float) 105 | parser.add_argument('--iou-threshold', help='IoU threshold for counting a detection as positive (defaults to 0.5).', default=0.5, type=float) 106 | parser.add_argument('--max-detections', help='Maximum number of detections per image (defaults to 100).', default=100, type=int) 107 | parser.add_argument('--save-path', help='Path for saving images with detections (doesn\'t work for COCO).') 108 | parser.add_argument('--image-min-side', help='Rescale the image so the smallest side is min_side.', type=int, default=800) 109 | parser.add_argument('--image-max-side', help='Rescale the image if the largest side is larger than max_side.', type=int, default=1333) 110 | parser.add_argument('--config', help='Path to a configuration parameters .ini file (only used with --convert-model).') 111 | 112 | return parser.parse_args(args) 113 | 114 | 115 | def main(args=None): 116 | # parse arguments 117 | if args is None: 118 | args = sys.argv[1:] 119 | args = parse_args(args) 120 | 121 | # make sure keras is at least the minimum required version 122 | check_keras_version() 123 | 124 | # optionally choose specific GPU 125 | if args.gpu: 126 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 127 | keras.backend.tensorflow_backend.set_session(get_session()) 128 | 129 | # make save path if it doesn't exist 130 | if args.save_path is not None and not os.path.exists(args.save_path): 131 | os.makedirs(args.save_path) 132 | 133 | # optionally load config parameters 134 | if args.config: 135 | args.config = read_config_file(args.config) 136 | 137 | # create the generator 138 | generator = create_generator(args) 139 | 140 | # optionally load anchor parameters 141 | anchor_params = None 142 | if args.config and 'anchor_parameters' in args.config: 143 | anchor_params = parse_anchor_parameters(args.config) 144 | 145 | # load the model 146 | print('Loading model, this may take a second...') 147 | model = models.load_model(args.model, backbone_name=args.backbone) 148 | 149 | # optionally convert the model 150 | if args.convert_model: 151 | model = models.convert_model(model, anchor_params=anchor_params) 152 | 153 | # print model summary 154 | # print(model.summary()) 155 | 156 | # start evaluation 157 | if args.dataset_type == 'coco': 158 | from ..utils.coco_eval import evaluate_coco 159 | evaluate_coco(generator, model, args.score_threshold) 160 | else: 161 | average_precisions = evaluate( 162 | generator, 163 | model, 164 | iou_threshold=args.iou_threshold, 165 | score_threshold=args.score_threshold, 166 | max_detections=args.max_detections, 167 | save_path=args.save_path
168 | ) 169 | 170 | # print evaluation 171 | total_instances = [] 172 | precisions = [] 173 | for label, (average_precision, num_annotations) in average_precisions.items(): 174 | print('{:.0f} instances of class'.format(num_annotations), 175 | generator.label_to_name(label), 'with average precision: {:.4f}'.format(average_precision)) 176 | total_instances.append(num_annotations) 177 | precisions.append(average_precision) 178 | 179 | if sum(total_instances) == 0: 180 | print('No test instances found.') 181 | return 182 | 183 | print('mAP using the weighted average of precisions among classes: {:.4f}'.format(sum([a * b for a, b in zip(total_instances, precisions)]) / sum(total_instances))) 184 | print('mAP: {:.4f}'.format(sum(precisions) / sum(x > 0 for x in total_instances))) 185 | 186 | 187 | if __name__ == '__main__': 188 | main() 189 | -------------------------------------------------------------------------------- /tests/utils/test_anchors.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import configparser 3 | import keras 4 | 5 | from keras_retinanet.utils.anchors import anchors_for_shape, AnchorParameters 6 | from keras_retinanet.utils.config import read_config_file, parse_anchor_parameters 7 | 8 | 9 | def test_config_read(): 10 | config = read_config_file('tests/test-data/config/config.ini') 11 | assert 'anchor_parameters' in config 12 | assert 'sizes' in config['anchor_parameters'] 13 | assert 'strides' in config['anchor_parameters'] 14 | assert 'ratios' in config['anchor_parameters'] 15 | assert 'scales' in config['anchor_parameters'] 16 | assert config['anchor_parameters']['sizes'] == '32 64 128 256 512' 17 | assert config['anchor_parameters']['strides'] == '8 16 32 64 128' 18 | assert config['anchor_parameters']['ratios'] == '0.5 1 2 3' 19 | assert config['anchor_parameters']['scales'] == '1 1.2 1.6' 20 | 21 | 22 | def create_anchor_params_config(): 23 | config = configparser.ConfigParser() 24 | config['anchor_parameters'] = {} 25 | config['anchor_parameters']['sizes'] = '32 64 128 256 512' 26 | config['anchor_parameters']['strides'] = '8 16 32 64 128' 27 | config['anchor_parameters']['ratios'] = '0.5 1' 28 | config['anchor_parameters']['scales'] = '1 1.2 1.6' 29 | 30 | return config 31 | 32 | 33 | def test_parse_anchor_parameters(): 34 | config = create_anchor_params_config() 35 | anchor_params_parsed = parse_anchor_parameters(config) 36 | 37 | sizes = [32, 64, 128, 256, 512] 38 | strides = [8, 16, 32, 64, 128] 39 | ratios = np.array([0.5, 1], keras.backend.floatx()) 40 | scales = np.array([1, 1.2, 1.6], keras.backend.floatx()) 41 | 42 | assert sizes == anchor_params_parsed.sizes 43 | assert strides == anchor_params_parsed.strides 44 | np.testing.assert_equal(ratios, anchor_params_parsed.ratios) 45 | np.testing.assert_equal(scales, anchor_params_parsed.scales) 46 | 47 | 48 | def test_anchors_for_shape_dimensions(): 49 | sizes = [32, 64, 128] 50 | strides = [8, 16, 32] 51 | ratios = np.array([0.5, 1, 2, 3], keras.backend.floatx()) 52 | scales = np.array([1, 1.2, 1.6], keras.backend.floatx()) 53 | anchor_params = AnchorParameters(sizes, strides, ratios, scales) 54 | 55 | pyramid_levels = [3, 4, 5] 56 | image_shape = (64, 64) 57 | all_anchors = anchors_for_shape(image_shape, pyramid_levels=pyramid_levels, anchor_params=anchor_params) 58 | 59 | assert all_anchors.shape == (1008, 4) 60 | 61 | 62 | def test_anchors_for_shape_values(): 63 | sizes = [12] 64 | strides = [8] 65 | ratios = np.array([1, 2], 
keras.backend.floatx()) 66 | scales = np.array([1, 2], keras.backend.floatx()) 67 | anchor_params = AnchorParameters(sizes, strides, ratios, scales) 68 | 69 | pyramid_levels = [3] 70 | image_shape = (16, 16) 71 | all_anchors = anchors_for_shape(image_shape, pyramid_levels=pyramid_levels, anchor_params=anchor_params) 72 | 73 | # using almost_equal for floating point imprecisions 74 | np.testing.assert_almost_equal(all_anchors[0, :], [ 75 | strides[0] / 2 - (sizes[0] * scales[0] / np.sqrt(ratios[0])) / 2, 76 | strides[0] / 2 - (sizes[0] * scales[0] * np.sqrt(ratios[0])) / 2, 77 | strides[0] / 2 + (sizes[0] * scales[0] / np.sqrt(ratios[0])) / 2, 78 | strides[0] / 2 + (sizes[0] * scales[0] * np.sqrt(ratios[0])) / 2, 79 | ], decimal=6) 80 | np.testing.assert_almost_equal(all_anchors[1, :], [ 81 | strides[0] / 2 - (sizes[0] * scales[1] / np.sqrt(ratios[0])) / 2, 82 | strides[0] / 2 - (sizes[0] * scales[1] * np.sqrt(ratios[0])) / 2, 83 | strides[0] / 2 + (sizes[0] * scales[1] / np.sqrt(ratios[0])) / 2, 84 | strides[0] / 2 + (sizes[0] * scales[1] * np.sqrt(ratios[0])) / 2, 85 | ], decimal=6) 86 | np.testing.assert_almost_equal(all_anchors[2, :], [ 87 | strides[0] / 2 - (sizes[0] * scales[0] / np.sqrt(ratios[1])) / 2, 88 | strides[0] / 2 - (sizes[0] * scales[0] * np.sqrt(ratios[1])) / 2, 89 | strides[0] / 2 + (sizes[0] * scales[0] / np.sqrt(ratios[1])) / 2, 90 | strides[0] / 2 + (sizes[0] * scales[0] * np.sqrt(ratios[1])) / 2, 91 | ], decimal=6) 92 | np.testing.assert_almost_equal(all_anchors[3, :], [ 93 | strides[0] / 2 - (sizes[0] * scales[1] / np.sqrt(ratios[1])) / 2, 94 | strides[0] / 2 - (sizes[0] * scales[1] * np.sqrt(ratios[1])) / 2, 95 | strides[0] / 2 + (sizes[0] * scales[1] / np.sqrt(ratios[1])) / 2, 96 | strides[0] / 2 + (sizes[0] * scales[1] * np.sqrt(ratios[1])) / 2, 97 | ], decimal=6) 98 | np.testing.assert_almost_equal(all_anchors[4, :], [ 99 | strides[0] * 3 / 2 - (sizes[0] * scales[0] / np.sqrt(ratios[0])) / 2, 100 | strides[0] / 2 - (sizes[0] * scales[0] * np.sqrt(ratios[0])) / 2, 101 | strides[0] * 3 / 2 + (sizes[0] * scales[0] / np.sqrt(ratios[0])) / 2, 102 | strides[0] / 2 + (sizes[0] * scales[0] * np.sqrt(ratios[0])) / 2, 103 | ], decimal=6) 104 | np.testing.assert_almost_equal(all_anchors[5, :], [ 105 | strides[0] * 3 / 2 - (sizes[0] * scales[1] / np.sqrt(ratios[0])) / 2, 106 | strides[0] / 2 - (sizes[0] * scales[1] * np.sqrt(ratios[0])) / 2, 107 | strides[0] * 3 / 2 + (sizes[0] * scales[1] / np.sqrt(ratios[0])) / 2, 108 | strides[0] / 2 + (sizes[0] * scales[1] * np.sqrt(ratios[0])) / 2, 109 | ], decimal=6) 110 | np.testing.assert_almost_equal(all_anchors[6, :], [ 111 | strides[0] * 3 / 2 - (sizes[0] * scales[0] / np.sqrt(ratios[1])) / 2, 112 | strides[0] / 2 - (sizes[0] * scales[0] * np.sqrt(ratios[1])) / 2, 113 | strides[0] * 3 / 2 + (sizes[0] * scales[0] / np.sqrt(ratios[1])) / 2, 114 | strides[0] / 2 + (sizes[0] * scales[0] * np.sqrt(ratios[1])) / 2, 115 | ], decimal=6) 116 | np.testing.assert_almost_equal(all_anchors[7, :], [ 117 | strides[0] * 3 / 2 - (sizes[0] * scales[1] / np.sqrt(ratios[1])) / 2, 118 | strides[0] / 2 - (sizes[0] * scales[1] * np.sqrt(ratios[1])) / 2, 119 | strides[0] * 3 / 2 + (sizes[0] * scales[1] / np.sqrt(ratios[1])) / 2, 120 | strides[0] / 2 + (sizes[0] * scales[1] * np.sqrt(ratios[1])) / 2, 121 | ], decimal=6) 122 | np.testing.assert_almost_equal(all_anchors[8, :], [ 123 | strides[0] / 2 - (sizes[0] * scales[0] / np.sqrt(ratios[0])) / 2, 124 | strides[0] * 3 / 2 - (sizes[0] * scales[0] * np.sqrt(ratios[0])) / 2, 125 | strides[0] / 2 + (sizes[0] 
* scales[0] / np.sqrt(ratios[0])) / 2, 126 | strides[0] * 3 / 2 + (sizes[0] * scales[0] * np.sqrt(ratios[0])) / 2, 127 | ], decimal=6) 128 | np.testing.assert_almost_equal(all_anchors[9, :], [ 129 | strides[0] / 2 - (sizes[0] * scales[1] / np.sqrt(ratios[0])) / 2, 130 | strides[0] * 3 / 2 - (sizes[0] * scales[1] * np.sqrt(ratios[0])) / 2, 131 | strides[0] / 2 + (sizes[0] * scales[1] / np.sqrt(ratios[0])) / 2, 132 | strides[0] * 3 / 2 + (sizes[0] * scales[1] * np.sqrt(ratios[0])) / 2, 133 | ], decimal=6) 134 | np.testing.assert_almost_equal(all_anchors[10, :], [ 135 | strides[0] / 2 - (sizes[0] * scales[0] / np.sqrt(ratios[1])) / 2, 136 | strides[0] * 3 / 2 - (sizes[0] * scales[0] * np.sqrt(ratios[1])) / 2, 137 | strides[0] / 2 + (sizes[0] * scales[0] / np.sqrt(ratios[1])) / 2, 138 | strides[0] * 3 / 2 + (sizes[0] * scales[0] * np.sqrt(ratios[1])) / 2, 139 | ], decimal=6) 140 | np.testing.assert_almost_equal(all_anchors[11, :], [ 141 | strides[0] / 2 - (sizes[0] * scales[1] / np.sqrt(ratios[1])) / 2, 142 | strides[0] * 3 / 2 - (sizes[0] * scales[1] * np.sqrt(ratios[1])) / 2, 143 | strides[0] / 2 + (sizes[0] * scales[1] / np.sqrt(ratios[1])) / 2, 144 | strides[0] * 3 / 2 + (sizes[0] * scales[1] * np.sqrt(ratios[1])) / 2, 145 | ], decimal=6) 146 | np.testing.assert_almost_equal(all_anchors[12, :], [ 147 | strides[0] * 3 / 2 - (sizes[0] * scales[0] / np.sqrt(ratios[0])) / 2, 148 | strides[0] * 3 / 2 - (sizes[0] * scales[0] * np.sqrt(ratios[0])) / 2, 149 | strides[0] * 3 / 2 + (sizes[0] * scales[0] / np.sqrt(ratios[0])) / 2, 150 | strides[0] * 3 / 2 + (sizes[0] * scales[0] * np.sqrt(ratios[0])) / 2, 151 | ], decimal=6) 152 | np.testing.assert_almost_equal(all_anchors[13, :], [ 153 | strides[0] * 3 / 2 - (sizes[0] * scales[1] / np.sqrt(ratios[0])) / 2, 154 | strides[0] * 3 / 2 - (sizes[0] * scales[1] * np.sqrt(ratios[0])) / 2, 155 | strides[0] * 3 / 2 + (sizes[0] * scales[1] / np.sqrt(ratios[0])) / 2, 156 | strides[0] * 3 / 2 + (sizes[0] * scales[1] * np.sqrt(ratios[0])) / 2, 157 | ], decimal=6) 158 | np.testing.assert_almost_equal(all_anchors[14, :], [ 159 | strides[0] * 3 / 2 - (sizes[0] * scales[0] / np.sqrt(ratios[1])) / 2, 160 | strides[0] * 3 / 2 - (sizes[0] * scales[0] * np.sqrt(ratios[1])) / 2, 161 | strides[0] * 3 / 2 + (sizes[0] * scales[0] / np.sqrt(ratios[1])) / 2, 162 | strides[0] * 3 / 2 + (sizes[0] * scales[0] * np.sqrt(ratios[1])) / 2, 163 | ], decimal=6) 164 | np.testing.assert_almost_equal(all_anchors[15, :], [ 165 | strides[0] * 3 / 2 - (sizes[0] * scales[1] / np.sqrt(ratios[1])) / 2, 166 | strides[0] * 3 / 2 - (sizes[0] * scales[1] * np.sqrt(ratios[1])) / 2, 167 | strides[0] * 3 / 2 + (sizes[0] * scales[1] / np.sqrt(ratios[1])) / 2, 168 | strides[0] * 3 / 2 + (sizes[0] * scales[1] * np.sqrt(ratios[1])) / 2, 169 | ], decimal=6) 170 | -------------------------------------------------------------------------------- /keras_retinanet/utils/eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | from .anchors import compute_overlap 18 | from .visualization import draw_detections, draw_annotations 19 | 20 | import keras 21 | import numpy as np 22 | import os 23 | 24 | import cv2 25 | import progressbar 26 | assert(callable(progressbar.progressbar)), "Using wrong progressbar module, install 'progressbar2' instead." 27 | 28 | 29 | def _compute_ap(recall, precision): 30 | """ Compute the average precision, given the recall and precision curves. 31 | 32 | Code originally from https://github.com/rbgirshick/py-faster-rcnn. 33 | 34 | # Arguments 35 | recall: The recall curve (list). 36 | precision: The precision curve (list). 37 | # Returns 38 | The average precision as computed in py-faster-rcnn. 39 | """ 40 | # correct AP calculation 41 | # first append sentinel values at the end 42 | mrec = np.concatenate(([0.], recall, [1.])) 43 | mpre = np.concatenate(([0.], precision, [0.])) 44 | 45 | # compute the precision envelope 46 | for i in range(mpre.size - 1, 0, -1): 47 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 48 | 49 | # to calculate area under PR curve, look for points 50 | # where X axis (recall) changes value 51 | i = np.where(mrec[1:] != mrec[:-1])[0] 52 | 53 | # and sum (\Delta recall) * prec 54 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 55 | return ap 56 | 57 | 58 | def _get_detections(generator, model, score_threshold=0.05, max_detections=100, save_path=None): 59 | """ Get the detections from the model using the generator. 60 | 61 | The result is a list of lists such that the size is: 62 | all_detections[num_images][num_classes] = detections[num_detections, 5], where each detection is (x1, y1, x2, y2, score) 63 | 64 | # Arguments 65 | generator : The generator used to run images through the model. 66 | model : The model to run on the images. 67 | score_threshold : The score confidence threshold to use. 68 | max_detections : The maximum number of detections to use per image. 69 | save_path : The path to save the images with visualized detections to. 70 | # Returns 71 | A list of lists containing the detections for each image in the generator.
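For example (hypothetical indices), all_detections[0][1] holds the class-1 detections found in image 0, one (x1, y1, x2, y2, score) row per detection.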
72 | """ 73 | all_detections = [[None for i in range(generator.num_classes())] for j in range(generator.size())]  # one entry per class so the inner list can be indexed by label below 74 | 75 | for i in progressbar.progressbar(range(generator.size()), prefix='Running network: '): 76 | raw_image = generator.load_image(i) 77 | image = generator.preprocess_image(raw_image.copy()) 78 | image, scale = generator.resize_image(image) 79 | 80 | if keras.backend.image_data_format() == 'channels_first': 81 | image = image.transpose((2, 0, 1)) 82 | 83 | # run network 84 | boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0))[:3] 85 | 86 | # correct boxes for image scale 87 | boxes /= scale 88 | 89 | # select indices which have a score above the threshold 90 | indices = np.where(scores[0, :] > score_threshold)[0] 91 | 92 | # select those scores 93 | scores = scores[0][indices] 94 | 95 | # find the order with which to sort the scores 96 | scores_sort = np.argsort(-scores)[:max_detections] 97 | 98 | # select detections 99 | image_boxes = boxes[0, indices[scores_sort], :] 100 | image_scores = scores[scores_sort] 101 | image_labels = labels[0, indices[scores_sort]] 102 | image_detections = np.concatenate([image_boxes, np.expand_dims(image_scores, axis=1), np.expand_dims(image_labels, axis=1)], axis=1) 103 | 104 | if save_path is not None: 105 | draw_annotations(raw_image, generator.load_annotations(i), label_to_name=generator.label_to_name) 106 | draw_detections(raw_image, image_boxes, image_scores, image_labels, label_to_name=generator.label_to_name) 107 | 108 | cv2.imwrite(os.path.join(save_path, '{}.png'.format(i)), raw_image) 109 | 110 | # copy detections to all_detections 111 | for label in range(generator.num_classes()): 112 | if not generator.has_label(label): 113 | continue 114 | 115 | all_detections[i][label] = image_detections[image_detections[:, -1] == label, :-1] 116 | 117 | return all_detections 118 | 119 | 120 | def _get_annotations(generator): 121 | """ Get the ground truth annotations from the generator. 122 | 123 | The result is a list of lists such that the size is: 124 | all_annotations[num_images][num_classes] = annotations[num_annotations, 4] 125 | 126 | # Arguments 127 | generator : The generator used to retrieve ground truth annotations. 128 | # Returns 129 | A list of lists containing the annotations for each image in the generator. 130 | """ 131 | all_annotations = [[None for i in range(generator.num_classes())] for j in range(generator.size())] 132 | 133 | for i in progressbar.progressbar(range(generator.size()), prefix='Parsing annotations: '): 134 | # load the annotations 135 | annotations = generator.load_annotations(i) 136 | 137 | # copy annotations to all_annotations 138 | for label in range(generator.num_classes()): 139 | if not generator.has_label(label): 140 | continue 141 | 142 | all_annotations[i][label] = annotations['bboxes'][annotations['labels'] == label, :].copy() 143 | 144 | return all_annotations 145 | 146 | 147 | def evaluate( 148 | generator, 149 | model, 150 | iou_threshold=0.5, 151 | score_threshold=0.05, 152 | max_detections=100, 153 | save_path=None 154 | ): 155 | """ Evaluate a given dataset using a given model. 156 | 157 | # Arguments 158 | generator : The generator that represents the dataset to evaluate. 159 | model : The model to evaluate. 160 | iou_threshold : The minimum IoU overlap for a detection to count as a true positive. 161 | score_threshold : The score confidence threshold to use for detections.
162 | max_detections : The maximum number of detections to use per image. 163 | save_path : The path to save images with visualized detections to. 164 | # Returns 165 | A dict mapping each class label to a tuple of (average precision, number of annotations). 166 | """ 167 | # gather all detections and annotations 168 | all_detections = _get_detections(generator, model, score_threshold=score_threshold, max_detections=max_detections, save_path=save_path) 169 | all_annotations = _get_annotations(generator) 170 | average_precisions = {} 171 | 172 | # all_detections = pickle.load(open('all_detections.pkl', 'rb')) 173 | # all_annotations = pickle.load(open('all_annotations.pkl', 'rb')) 174 | # pickle.dump(all_detections, open('all_detections.pkl', 'wb')) 175 | # pickle.dump(all_annotations, open('all_annotations.pkl', 'wb')) 176 | 177 | # process detections and annotations 178 | for label in range(generator.num_classes()): 179 | if not generator.has_label(label): 180 | continue 181 | 182 | false_positives = np.zeros((0,)) 183 | true_positives = np.zeros((0,)) 184 | scores = np.zeros((0,)) 185 | num_annotations = 0.0 186 | 187 | for i in range(generator.size()): 188 | detections = all_detections[i][label] 189 | annotations = all_annotations[i][label] 190 | num_annotations += annotations.shape[0] 191 | detected_annotations = [] 192 | 193 | for d in detections: 194 | scores = np.append(scores, d[4]) 195 | 196 | if annotations.shape[0] == 0: 197 | false_positives = np.append(false_positives, 1) 198 | true_positives = np.append(true_positives, 0) 199 | continue 200 | 201 | overlaps = compute_overlap(np.expand_dims(d, axis=0), annotations) 202 | assigned_annotation = np.argmax(overlaps, axis=1) 203 | max_overlap = overlaps[0, assigned_annotation] 204 | 205 | if max_overlap >= iou_threshold and assigned_annotation not in detected_annotations: 206 | false_positives = np.append(false_positives, 0) 207 | true_positives = np.append(true_positives, 1) 208 | detected_annotations.append(assigned_annotation) 209 | else: 210 | false_positives = np.append(false_positives, 1) 211 | true_positives = np.append(true_positives, 0) 212 | 213 | # no annotations -> AP for this class is 0 (is this correct?) 214 | if num_annotations == 0: 215 | average_precisions[label] = 0, 0 216 | continue 217 | 218 | # sort by score 219 | indices = np.argsort(-scores) 220 | false_positives = false_positives[indices] 221 | true_positives = true_positives[indices] 222 | 223 | # compute cumulative false positives and true positives 224 | false_positives = np.cumsum(false_positives) 225 | true_positives = np.cumsum(true_positives) 226 | 227 | # compute recall and precision 228 | recall = true_positives / num_annotations 229 | precision = true_positives / np.maximum(true_positives + false_positives, np.finfo(np.float64).eps) 230 | 231 | # compute average precision 232 | average_precision = _compute_ap(recall, precision) 233 | average_precisions[label] = average_precision, num_annotations 234 | 235 | return average_precisions 236 | -------------------------------------------------------------------------------- /tests/preprocessing/test_generator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | from keras_retinanet.preprocessing.generator import Generator 18 | 19 | import numpy as np 20 | import pytest 21 | 22 | 23 | class SimpleGenerator(Generator): 24 | def __init__(self, bboxes, labels, num_classes=0, image=None): 25 | assert(len(bboxes) == len(labels)) 26 | self.bboxes = bboxes 27 | self.labels = labels 28 | self.num_classes_ = num_classes 29 | self.image = image 30 | super(SimpleGenerator, self).__init__(group_method='none', shuffle_groups=False) 31 | 32 | def num_classes(self): 33 | return self.num_classes_ 34 | 35 | def load_image(self, image_index): 36 | return self.image 37 | 38 | def size(self): 39 | return len(self.bboxes) 40 | 41 | def load_annotations(self, image_index): 42 | annotations = {'labels': self.labels[image_index], 'bboxes': self.bboxes[image_index]} 43 | return annotations 44 | 45 | 46 | class TestLoadAnnotationsGroup(object): 47 | def test_simple(self): 48 | input_bboxes_group = [ 49 | np.array([ 50 | [ 0, 0, 10, 10], 51 | [150, 150, 350, 350] 52 | ]), 53 | ] 54 | input_labels_group = [ 55 | np.array([ 56 | 1, 57 | 3 58 | ]), 59 | ] 60 | expected_bboxes_group = input_bboxes_group 61 | expected_labels_group = input_labels_group 62 | 63 | simple_generator = SimpleGenerator(input_bboxes_group, input_labels_group) 64 | annotations = simple_generator.load_annotations_group(simple_generator.groups[0]) 65 | 66 | assert('bboxes' in annotations[0]) 67 | assert('labels' in annotations[0]) 68 | np.testing.assert_equal(expected_bboxes_group[0], annotations[0]['bboxes']) 69 | np.testing.assert_equal(expected_labels_group[0], annotations[0]['labels']) 70 | 71 | def test_multiple(self): 72 | input_bboxes_group = [ 73 | np.array([ 74 | [ 0, 0, 10, 10], 75 | [150, 150, 350, 350] 76 | ]), 77 | np.array([ 78 | [0, 0, 50, 50], 79 | ]), 80 | ] 81 | input_labels_group = [ 82 | np.array([ 83 | 1, 84 | 0 85 | ]), 86 | np.array([ 87 | 3 88 | ]) 89 | ] 90 | expected_bboxes_group = input_bboxes_group 91 | expected_labels_group = input_labels_group 92 | 93 | simple_generator = SimpleGenerator(input_bboxes_group, input_labels_group) 94 | annotations_group_0 = simple_generator.load_annotations_group(simple_generator.groups[0]) 95 | annotations_group_1 = simple_generator.load_annotations_group(simple_generator.groups[1]) 96 | 97 | assert('bboxes' in annotations_group_0[0]) 98 | assert('bboxes' in annotations_group_1[0]) 99 | assert('labels' in annotations_group_0[0]) 100 | assert('labels' in annotations_group_1[0]) 101 | np.testing.assert_equal(expected_bboxes_group[0], annotations_group_0[0]['bboxes']) 102 | np.testing.assert_equal(expected_labels_group[0], annotations_group_0[0]['labels']) 103 | np.testing.assert_equal(expected_bboxes_group[1], annotations_group_1[0]['bboxes']) 104 | np.testing.assert_equal(expected_labels_group[1], annotations_group_1[0]['labels']) 105 | 106 | 107 | class TestFilterAnnotations(object): 108 | def test_simple_filter(self): 109 | input_bboxes_group = [ 110 | np.array([ 111 | [ 0, 0, 10, 10], 112 | [150, 150, 50, 50] 113 | ]), 114 | ] 115 | input_labels_group = [ 116 | np.array([ 117 | 3, 118 | 1 119 | ]), 120 | 
] 121 | 122 | input_image = np.zeros((500, 500, 3)) 123 | 124 | expected_bboxes_group = [ 125 | np.array([ 126 | [0, 0, 10, 10], 127 | ]), 128 | ] 129 | expected_labels_group = [ 130 | np.array([ 131 | 3, 132 | ]), 133 | ] 134 | 135 | simple_generator = SimpleGenerator(input_bboxes_group, input_labels_group) 136 | annotations = simple_generator.load_annotations_group(simple_generator.groups[0]) 137 | # expect a UserWarning 138 | with pytest.warns(UserWarning): 139 | image_group, annotations_group = simple_generator.filter_annotations([input_image], annotations, simple_generator.groups[0]) 140 | 141 | np.testing.assert_equal(expected_bboxes_group[0], annotations_group[0]['bboxes']) 142 | np.testing.assert_equal(expected_labels_group[0], annotations_group[0]['labels']) 143 | 144 | def test_multiple_filter(self): 145 | input_bboxes_group = [ 146 | np.array([ 147 | [ 0, 0, 10, 10], 148 | [150, 150, 50, 50], 149 | [150, 150, 350, 350], 150 | [350, 350, 150, 150], 151 | [ 1, 1, 2, 2], 152 | [ 2, 2, 1, 1] 153 | ]), 154 | np.array([ 155 | [0, 0, -1, -1] 156 | ]), 157 | np.array([ 158 | [-10, -10, 0, 0], 159 | [-10, -10, -100, -100], 160 | [ 10, 10, 100, 100] 161 | ]), 162 | np.array([ 163 | [ 10, 10, 100, 100], 164 | [ 10, 10, 600, 600] 165 | ]), 166 | ] 167 | 168 | input_labels_group = [ 169 | np.array([ 170 | 6, 171 | 5, 172 | 4, 173 | 3, 174 | 2, 175 | 1 176 | ]), 177 | np.array([ 178 | 0 179 | ]), 180 | np.array([ 181 | 10, 182 | 11, 183 | 12 184 | ]), 185 | np.array([ 186 | 105, 187 | 107 188 | ]), 189 | ] 190 | 191 | input_image = np.zeros((500, 500, 3)) 192 | 193 | expected_bboxes_group = [ 194 | np.array([ 195 | [ 0, 0, 10, 10], 196 | [150, 150, 350, 350], 197 | [ 1, 1, 2, 2] 198 | ]), 199 | np.zeros((0, 4)), 200 | np.array([ 201 | [10, 10, 100, 100] 202 | ]), 203 | np.array([ 204 | [ 10, 10, 100, 100] 205 | ]), 206 | ] 207 | expected_labels_group = [ 208 | np.array([ 209 | 6, 210 | 4, 211 | 2 212 | ]), 213 | np.zeros((0,)), 214 | np.array([ 215 | 12 216 | ]), 217 | np.array([ 218 | 105 219 | ]), 220 | ] 221 | 222 | simple_generator = SimpleGenerator(input_bboxes_group, input_labels_group) 223 | # expect a UserWarning 224 | annotations_group_0 = simple_generator.load_annotations_group(simple_generator.groups[0]) 225 | with pytest.warns(UserWarning): 226 | image_group, annotations_group_0 = simple_generator.filter_annotations([input_image], annotations_group_0, simple_generator.groups[0]) 227 | 228 | annotations_group_1 = simple_generator.load_annotations_group(simple_generator.groups[1]) 229 | with pytest.warns(UserWarning): 230 | image_group, annotations_group_1 = simple_generator.filter_annotations([input_image], annotations_group_1, simple_generator.groups[1]) 231 | 232 | annotations_group_2 = simple_generator.load_annotations_group(simple_generator.groups[2]) 233 | with pytest.warns(UserWarning): 234 | image_group, annotations_group_2 = simple_generator.filter_annotations([input_image], annotations_group_2, simple_generator.groups[2]) 235 | 236 | np.testing.assert_equal(expected_bboxes_group[0], annotations_group_0[0]['bboxes']) 237 | np.testing.assert_equal(expected_labels_group[0], annotations_group_0[0]['labels']) 238 | 239 | np.testing.assert_equal(expected_bboxes_group[1], annotations_group_1[0]['bboxes']) 240 | np.testing.assert_equal(expected_labels_group[1], annotations_group_1[0]['labels']) 241 | 242 | np.testing.assert_equal(expected_bboxes_group[2], annotations_group_2[0]['bboxes']) 243 | np.testing.assert_equal(expected_labels_group[2], annotations_group_2[0]['labels']) 244 
| 245 | def test_complete(self): 246 | input_bboxes_group = [ 247 | np.array([ 248 | [ 0, 0, 50, 50], 249 | [150, 150, 50, 50], # invalid bbox 250 | ], dtype=float) 251 | ] 252 | 253 | input_labels_group = [ 254 | np.array([ 255 | 5, # one object of class 5 256 | 3, # one object of class 3 with an invalid box 257 | ], dtype=float) 258 | ] 259 | 260 | input_image = np.zeros((500, 500, 3), dtype=np.uint8) 261 | 262 | simple_generator = SimpleGenerator(input_bboxes_group, input_labels_group, image=input_image, num_classes=6) 263 | # expect a UserWarning 264 | with pytest.warns(UserWarning): 265 | _, [_, labels_batch] = simple_generator[0] 266 | 267 | # test that only object with class 5 is present in labels_batch 268 | labels = np.unique(np.argmax(labels_batch == 5, axis=2)) 269 | assert(len(labels) == 1 and labels[0] == 0), 'Expected only class 0 to be present, but got classes {}'.format(labels) 270 | --------------------------------------------------------------------------------
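# Usage sketch (hypothetical paths, not files from this repository): the evaluation script above is typically invoked on a CSV dataset as 'python keras_retinanet/bin/evaluate.py csv annotations.csv classes.csv model.h5', adding --convert-model when model.h5 is a training model; annotations.csv and classes.csv follow the formats parsed by keras_retinanet/preprocessing/csv_generator.py.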