├── .gitignore ├── LICENCE ├── README.md ├── image ├── example_frame.png └── pr-curve.png ├── mean_average_precision ├── __init__.py ├── ap_accumulator.py ├── detection_map.py ├── example.py ├── pr_curve_example.png └── utils │ ├── __init__.py │ ├── bbox.py │ └── show_frame.py ├── setup.py └── test └── test_evaluate.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/* 2 | 3 | # Created by .ignore support plugin (hsz.mobi) 4 | ### Python template 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | env/ 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *,cover 51 | .hypothesis/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # dotenv 87 | .env 88 | 89 | # virtualenv 90 | .venv 91 | venv/ 92 | ENV/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Mathieu Garon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Detection mAP 2 | 3 | A simple utility tool to evaluate Bounding box classification task following Pascal VOC [paper](http://homepages.inf.ed.ac.uk/ckiw/postscript/ijcv_voc09.pdf). 4 | 5 | To learn about this metric I recommend this excellent blog post by Sancho McCann before reading the paper : [link](https://sanchom.wordpress.com/2011/09/01/precision-recall) 6 | 7 | **Note that the method is not compared with the original VOC implementation! (See Todo)** 8 | 9 | ## features 10 | - Simple : numpy and matplotlib are the only dependencies 11 | - Compute a running evaluation : input prediction/ground truth at each frames, no need to save in files 12 | - Plot (matplotlib) per class pr-curves with interpolated average precision (default) or average precision 13 | 14 | ## Method 15 | ### Multiclass mAP 16 | Handle every class as one against the others. (x against z) 17 | - True positive (TP): 18 | - Gt x predicted as x 19 | - False positive (FP): 20 | - Prediction x if Gt x has already a TP prediction 21 | - Prediction x not overlapping any Gt x 22 | - False negative (FN): 23 | - Gt x not predicted as x 24 | ### Example frame 25 | ![example](https://github.com/MathGaron/mean_average_precision/raw/master/image/example_frame.png "example frame") 26 | 27 | ## Code 28 | All you need is your predicted bounding boxes with class and confidence score and the ground truth bounding boxes with their classes. 29 | 30 | ### [main loop](https://github.com/MathGaron/mean_average_precision/blob/master/mean_average_precision/example.py) : 31 | ```python 32 | frames = [(pred_bb1, pred_cls1, pred_conf1, gt_bb1, gt_cls1), 33 | (pred_bb2, pred_cls2, pred_conf2, gt_bb2, gt_cls2), 34 | (pred_bb3, pred_cls3, pred_conf3, gt_bb3, gt_cls3)] 35 | n_class = 7 36 | 37 | mAP = DetectionMAP(n_class) 38 | for frame in frames: 39 | mAP.evaluate(*frame) 40 | 41 | mAP.plot() 42 | plt.show() # or plt.savefig(path) 43 | ``` 44 | In this example a frame is a tuple containing: 45 | - Predicted bounding boxes : numpy array [n, 4] 46 | - Predicted classes: numpy array [n] 47 | - Predicted confidences: numpy array [n] 48 | - Ground truth bounding boxes:numpy array [m, 4] 49 | - Ground truth classes: numpy array [m] 50 | 51 | Note that the bounding boxes are represented as two corners points : [x1, y1, x2, y2] 52 | 53 | ![example](https://github.com/MathGaron/mean_average_precision/raw/master/image/pr-curve.png "pr-curves") 54 | 55 | ### TODO 56 | - ~~Interpolated average precision~~ 57 | - Test against [VOC matlab implementation](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/htmldoc/devkit_doc.html) 58 | 59 | ### Contribution 60 | And of course any bugfixes/contribution are always welcome! 61 | -------------------------------------------------------------------------------- /image/example_frame.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathGaron/mean_average_precision/2be3b1923251e734b4400a7a9115620260c759af/image/example_frame.png -------------------------------------------------------------------------------- /image/pr-curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathGaron/mean_average_precision/2be3b1923251e734b4400a7a9115620260c759af/image/pr-curve.png -------------------------------------------------------------------------------- /mean_average_precision/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathGaron/mean_average_precision/2be3b1923251e734b4400a7a9115620260c759af/mean_average_precision/__init__.py -------------------------------------------------------------------------------- /mean_average_precision/ap_accumulator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple accumulator class that keeps track of True positive, False positive and False negative 3 | to compute precision and recall of a certain class 4 | """ 5 | 6 | 7 | class APAccumulator: 8 | def __init__(self): 9 | self.TP, self.FP, self.FN = 0, 0, 0 10 | 11 | def inc_good_prediction(self, value=1): 12 | self.TP += value 13 | 14 | def inc_bad_prediction(self, value=1): 15 | self.FP += value 16 | 17 | def inc_not_predicted(self, value=1): 18 | self.FN += value 19 | 20 | @property 21 | def precision(self): 22 | total_predicted = self.TP + self.FP 23 | if total_predicted == 0: 24 | total_gt = self.TP + self.FN 25 | if total_gt == 0: 26 | return 1. 27 | else: 28 | return 0. 29 | return float(self.TP) / total_predicted 30 | 31 | @property 32 | def recall(self): 33 | total_gt = self.TP + self.FN 34 | if total_gt == 0: 35 | return 1. 36 | return float(self.TP) / total_gt 37 | 38 | def __str__(self): 39 | str = "" 40 | str += "True positives : {}\n".format(self.TP) 41 | str += "False positives : {}\n".format(self.FP) 42 | str += "False Negatives : {}\n".format(self.FN) 43 | str += "Precision : {}\n".format(self.precision) 44 | str += "Recall : {}\n".format(self.recall) 45 | return str -------------------------------------------------------------------------------- /mean_average_precision/detection_map.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from mean_average_precision.ap_accumulator import APAccumulator 3 | from mean_average_precision.utils.bbox import jaccard 4 | import math 5 | import matplotlib.pyplot as plt 6 | 7 | DEBUG = False 8 | 9 | 10 | class DetectionMAP: 11 | def __init__(self, n_class, pr_samples=11, overlap_threshold=0.5): 12 | """ 13 | Running computation of average precision of n_class in a bounding box + classification task 14 | :param n_class: quantity of class 15 | :param pr_samples: quantification of threshold for pr curve 16 | :param overlap_threshold: minimum overlap threshold 17 | """ 18 | self.n_class = n_class 19 | self.overlap_threshold = overlap_threshold 20 | self.pr_scale = np.linspace(0, 1, pr_samples) 21 | self.total_accumulators = [] 22 | self.reset_accumulators() 23 | 24 | def reset_accumulators(self): 25 | """ 26 | Reset the accumulators state 27 | TODO this is hard to follow... should use a better data structure 28 | total_accumulators : list of list of accumulators at each pr_scale for each class 29 | :return: 30 | """ 31 | self.total_accumulators = [] 32 | for i in range(len(self.pr_scale)): 33 | class_accumulators = [] 34 | for j in range(self.n_class): 35 | class_accumulators.append(APAccumulator()) 36 | self.total_accumulators.append(class_accumulators) 37 | 38 | def evaluate(self, pred_bb, pred_classes, pred_conf, gt_bb, gt_classes): 39 | """ 40 | Update the accumulator for the running mAP evaluation. 41 | For exemple, this can be called for each images 42 | :param pred_bb: (np.array) Predicted Bounding Boxes [x1, y1, x2, y2] : Shape [n_pred, 4] 43 | :param pred_classes: (np.array) Predicted Classes : Shape [n_pred] 44 | :param pred_conf: (np.array) Predicted Confidences [0.-1.] : Shape [n_pred] 45 | :param gt_bb: (np.array) Ground Truth Bounding Boxes [x1, y1, x2, y2] : Shape [n_gt, 4] 46 | :param gt_classes: (np.array) Ground Truth Classes : Shape [n_gt] 47 | :return: 48 | """ 49 | 50 | if pred_bb.ndim == 1: 51 | pred_bb = np.repeat(pred_bb[:, np.newaxis], 4, axis=1) 52 | IoUmask = None 53 | if len(pred_bb) > 0: 54 | IoUmask = self.compute_IoU_mask(pred_bb, gt_bb, self.overlap_threshold) 55 | for accumulators, r in zip(self.total_accumulators, self.pr_scale): 56 | if DEBUG: 57 | print("Evaluate pr_scale {}".format(r)) 58 | self.evaluate_(IoUmask, accumulators, pred_classes, pred_conf, gt_classes, r) 59 | 60 | @staticmethod 61 | def evaluate_(IoUmask, accumulators, pred_classes, pred_conf, gt_classes, confidence_threshold): 62 | pred_classes = pred_classes.astype(np.int) 63 | gt_classes = gt_classes.astype(np.int) 64 | 65 | for i, acc in enumerate(accumulators): 66 | gt_number = np.sum(gt_classes == i) 67 | pred_mask = np.logical_and(pred_classes == i, pred_conf >= confidence_threshold) 68 | pred_number = np.sum(pred_mask) 69 | if pred_number == 0: 70 | acc.inc_not_predicted(gt_number) 71 | continue 72 | 73 | IoU1 = IoUmask[pred_mask, :] 74 | mask = IoU1[:, gt_classes == i] 75 | 76 | tp = DetectionMAP.compute_true_positive(mask) 77 | fp = pred_number - tp 78 | fn = gt_number - tp 79 | acc.inc_good_prediction(tp) 80 | acc.inc_not_predicted(fn) 81 | acc.inc_bad_prediction(fp) 82 | 83 | @staticmethod 84 | def compute_IoU_mask(prediction, gt, overlap_threshold): 85 | IoU = jaccard(prediction, gt) 86 | # for each prediction select gt with the largest IoU and ignore the others 87 | for i in range(len(prediction)): 88 | maxj = IoU[i, :].argmax() 89 | IoU[i, :maxj] = 0 90 | IoU[i, (maxj + 1):] = 0 91 | # make a mask of all "matched" predictions vs gt 92 | return IoU >= overlap_threshold 93 | 94 | @staticmethod 95 | def compute_true_positive(mask): 96 | # sum all gt with prediction of its class 97 | return np.sum(mask.any(axis=0)) 98 | 99 | def compute_ap(self, precisions, recalls): 100 | """ 101 | Compute average precision of a particular classes (cls_idx) 102 | :param cls: 103 | :return: 104 | """ 105 | previous_recall = 0 106 | average_precision = 0 107 | for precision, recall in zip(precisions[::-1], recalls[::-1]): 108 | average_precision += precision * (recall - previous_recall) 109 | previous_recall = recall 110 | return average_precision 111 | 112 | def compute_precision_recall_(self, class_index, interpolated=True): 113 | precisions = [] 114 | recalls = [] 115 | for acc in self.total_accumulators: 116 | precisions.append(acc[class_index].precision) 117 | recalls.append(acc[class_index].recall) 118 | 119 | if interpolated: 120 | interpolated_precision = [] 121 | for precision in precisions: 122 | last_max = 0 123 | if interpolated_precision: 124 | last_max = max(interpolated_precision) 125 | interpolated_precision.append(max(precision, last_max)) 126 | precisions = interpolated_precision 127 | return precisions, recalls 128 | 129 | def plot_pr(self, ax, class_name, precisions, recalls, average_precision): 130 | ax.step(recalls, precisions, color='b', alpha=0.2, 131 | where='post') 132 | ax.fill_between(recalls, precisions, step='post', alpha=0.2, 133 | color='b') 134 | ax.set_ylim([0.0, 1.05]) 135 | ax.set_xlim([0.0, 1.0]) 136 | ax.set_xlabel('Recall') 137 | ax.set_ylabel('Precision') 138 | ax.set_title('{0:} : AUC={1:0.2f}'.format(class_name, average_precision)) 139 | 140 | def plot(self, interpolated=True, class_names=None): 141 | """ 142 | Plot all pr-curves for each classes 143 | :param interpolated: will compute the interpolated curve 144 | :return: 145 | """ 146 | grid = int(math.ceil(math.sqrt(self.n_class))) 147 | fig, axes = plt.subplots(nrows=grid, ncols=grid) 148 | mean_average_precision = [] 149 | # TODO: data structure not optimal for this operation... 150 | for i, ax in enumerate(axes.flat): 151 | if i > self.n_class - 1: 152 | break 153 | precisions, recalls = self.compute_precision_recall_(i, interpolated) 154 | average_precision = self.compute_ap(precisions, recalls) 155 | class_name = class_names[i] if class_names else "Class {}".format(i) 156 | self.plot_pr(ax, class_name, precisions, recalls, average_precision) 157 | mean_average_precision.append(average_precision) 158 | 159 | plt.suptitle("Mean average precision : {:0.2f}".format(sum(mean_average_precision)/len(mean_average_precision))) 160 | fig.tight_layout() 161 | -------------------------------------------------------------------------------- /mean_average_precision/example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple Usage example (with 3 images) 3 | """ 4 | from mean_average_precision.detection_map import DetectionMAP 5 | from mean_average_precision.utils.show_frame import show_frame 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | 9 | pred_bb1 = np.array([[0.880688, 0.44609185, 0.95696718, 0.6476958], 10 | [0.84020283, 0.45787981, 0.99351478, 0.64294884], 11 | [0.78723741, 0.61799151, 0.9083041, 0.75623035], 12 | [0.22078986, 0.30151826, 0.36679274, 0.40551913], 13 | [0.0041579, 0.48359361, 0.06867643, 0.60145104], 14 | [0.4731401, 0.33888632, 0.75164948, 0.80546954], 15 | [0.75489414, 0.75228018, 0.87922037, 0.88110524], 16 | [0.21953127, 0.77934921, 0.34853417, 0.90626764], 17 | [0.81, 0.11, 0.91, 0.21]]) 18 | pred_cls1 = np.array([0, 0, 0, 1, 1, 2, 2, 2, 3]) 19 | pred_conf1 = np.array([0.95, 0.75, 0.4, 0.3, 1, 1, 0.75, 0.5, 0.8]) 20 | gt_bb1 = np.array([[0.86132812, 0.48242188, 0.97460938, 0.6171875], 21 | [0.18554688, 0.234375, 0.36132812, 0.41601562], 22 | [0., 0.47265625, 0.0703125, 0.62109375], 23 | [0.47070312, 0.3125, 0.77929688, 0.78125], 24 | [0.8, 0.1, 0.9, 0.2]]) 25 | gt_cls1 = np.array([0, 0, 1, 2, 3]) 26 | 27 | pred_bb2 = np.array([[0.6, 0.4, 0.8, 0.6], 28 | [0.45, 0.24, 0.55179688, 0.35179688], 29 | [0.2, 0.15, 0.29, 0.30], 30 | [0.95, 0.55, 0.99, 0.66670889], 31 | [0.62373358, 0.43393397, 0.82830238, 0.68219709], 32 | [0.8814062, 0.8921875, 0.94453125, 0.9704688], 33 | [0.8514062, 0.9121875, 0.99453125, 0.9804688], 34 | [0.40, 0.44, 0.55, 0.56], 35 | [0.1672115, 0.435711, 0.32729435, 0.57853043], 36 | [0.18287398, 0.15450388, 0.27082703, 0.31132805], 37 | [0.3713485, 0.24020095, 0.62879527, 0.48929602]]) 38 | pred_cls2 = np.array([0, 0, 0, 0, 0, 1, 1, 2, 2, 2, 2]) 39 | pred_conf2 = np.array([0.75, 0.78, 0.83, 0.42, 0.2457653, 40 | 0.95, 0.5, 0.81003532, 0.18837614, 0.77496605, 0.27333026]) 41 | gt_bb2 = np.array([[0.625, 0.43554688, 0.828125, 0.67382812], 42 | [0.45898438, 0.25390625, 0.59179688, 0.34179688], 43 | [0.18164062, 0.16015625, 0.28125, 0.31054688], 44 | [0.8914062, 0.8821875, 0.95453125, 0.9804688], 45 | [0.40234375, 0.44921875, 0.55078125, 0.56445312], 46 | [0.16796875, 0.43554688, 0.328125, 0.578125]]) 47 | gt_cls2 = np.array([0, 0, 0, 1, 2, 2]) 48 | 49 | pred_bb3 = np.array([[0.74, 0.58, 1.0, 0.83], 50 | [0.75, 0.575, 0.99, 0.83], 51 | [0.57, 0.23, 1.0, 0.62], 52 | [0.59, 0.24, 1.0, 0.63], 53 | [0.55, 0.24, 0.33, 0.7], 54 | [0.12, 0.21, 0.31, 0.39], 55 | [0.1240625, 0.2109375, 0.859375, 0.39453125], 56 | [2.86702722e-01, 5.87677717e-01, 3.90843153e-01, 7.14454949e-01], 57 | [2.87590116e-01, 8.76132399e-02, 3.79709303e-01, 2.05121845e-01]]) 58 | pred_cls3 = np.array( 59 | [0, 0, 0, 0, 0, 1, 1, 2, 2]) 60 | pred_conf3 = np.array([0.75, 0.90, 0.9, 0.9, 0.5, 0.84, 61 | 0.1, 0.2363426, 0.02707205]) 62 | gt_bb3 = np.array([[0.74609375, 0.58007812, 1.05273438, 0.83007812], 63 | [0.57226562, 0.234375, 1.14453125, 0.62890625], 64 | [0.1240625, 0.2109375, 0.329375, 0.39453125]]) 65 | gt_cls3 = np.array([0, 0, 1]) 66 | 67 | if __name__ == '__main__': 68 | frames = [(pred_bb1, pred_cls1, pred_conf1, gt_bb1, gt_cls1), 69 | (pred_bb2, pred_cls2, pred_conf2, gt_bb2, gt_cls2), 70 | (pred_bb3, pred_cls3, pred_conf3, gt_bb3, gt_cls3)] 71 | n_class = 4 72 | 73 | mAP = DetectionMAP(n_class) 74 | for i, frame in enumerate(frames): 75 | print("Evaluate frame {}".format(i)) 76 | show_frame(*frame) 77 | mAP.evaluate(*frame) 78 | 79 | mAP.plot() 80 | plt.show() 81 | #plt.savefig("pr_curve_example.png") 82 | -------------------------------------------------------------------------------- /mean_average_precision/pr_curve_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathGaron/mean_average_precision/2be3b1923251e734b4400a7a9115620260c759af/mean_average_precision/pr_curve_example.png -------------------------------------------------------------------------------- /mean_average_precision/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathGaron/mean_average_precision/2be3b1923251e734b4400a7a9115620260c759af/mean_average_precision/utils/__init__.py -------------------------------------------------------------------------------- /mean_average_precision/utils/bbox.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Bounding box intersection over union calculation. 4 | Borrowed from pytorch SSD implementation : https://github.com/amdegroot/ssd.pytorch/blob/master/layers/box_utils.py 5 | and adapted to numpy. 6 | """ 7 | import numpy as np 8 | 9 | 10 | def intersect_area(box_a, box_b): 11 | """ 12 | Compute the area of intersection between two rectangular bounding box 13 | Bounding boxes use corner notation : [x1, y1, x2, y2] 14 | Args: 15 | box_a: (np.array) bounding boxes, Shape: [A,4]. 16 | box_b: (np.array) bounding boxes, Shape: [B,4]. 17 | Return: 18 | np.array intersection area, Shape: [A,B]. 19 | """ 20 | resized_A = box_a[:, np.newaxis, :] 21 | resized_B = box_b[np.newaxis, :, :] 22 | max_xy = np.minimum(resized_A[:, :, 2:], resized_B[:, :, 2:]) 23 | min_xy = np.maximum(resized_A[:, :, :2], resized_B[:, :, :2]) 24 | 25 | diff_xy = (max_xy - min_xy) 26 | inter = np.clip(diff_xy, a_min=0, a_max=np.max(diff_xy)) 27 | return inter[:, :, 0] * inter[:, :, 1] 28 | 29 | 30 | def jaccard(box_a, box_b): 31 | """ 32 | Compute the jaccard overlap of two sets of boxes. The jaccard overlap 33 | is simply the intersection over union of two boxes. Here we operate on 34 | ground truth boxes and default boxes. 35 | E.g.: 36 | A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) 37 | Args: 38 | box_a: (np.array) Predicted bounding boxes, Shape: [n_pred, 4] 39 | box_b: (np.array) Ground Truth bounding boxes, Shape: [n_gt, 4] 40 | Return: 41 | jaccard overlap: (np.array) Shape: [n_pred, n_gt] 42 | """ 43 | inter = intersect_area(box_a, box_b) 44 | area_a = ((box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])) 45 | area_b = ((box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1])) 46 | area_a = area_a[:, np.newaxis] 47 | area_b = area_b[np.newaxis, :] 48 | union = area_a + area_b - inter 49 | return inter / union -------------------------------------------------------------------------------- /mean_average_precision/utils/show_frame.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import matplotlib.patches as patches 3 | import numpy as np 4 | 5 | 6 | def show_frame(pred_bb, pred_classes, pred_conf, gt_bb, gt_classes, background=np.zeros((500, 500, 3)), show_confidence=True): 7 | """ 8 | Plot the boundingboxes 9 | :param pred_bb: (np.array) Predicted Bounding Boxes [x1, y1, x2, y2] : Shape [n_pred, 4] 10 | :param pred_classes: (np.array) Predicted Classes : Shape [n_pred] 11 | :param pred_conf: (np.array) Predicted Confidences [0.-1.] : Shape [n_pred] 12 | :param gt_bb: (np.array) Ground Truth Bounding Boxes [x1, y1, x2, y2] : Shape [n_gt, 4] 13 | :param gt_classes: (np.array) Ground Truth Classes : Shape [n_gt] 14 | :return: 15 | """ 16 | n_pred = pred_bb.shape[0] 17 | n_gt = gt_bb.shape[0] 18 | n_class = np.max(np.append(pred_classes, gt_classes)) + 1 19 | h, w, c = background.shape 20 | 21 | ax = plt.subplot("111") 22 | ax.imshow(background) 23 | cmap = plt.cm.get_cmap('hsv') 24 | 25 | confidence_alpha = pred_conf.copy() 26 | if not show_confidence: 27 | confidence_alpha.fill(1) 28 | 29 | for i in range(n_pred): 30 | x1 = pred_bb[i, 0] * w 31 | y1 = pred_bb[i, 1] * h 32 | x2 = pred_bb[i, 2] * w 33 | y2 = pred_bb[i, 3] * h 34 | rect_w = x2 - x1 35 | rect_h = y2 - y1 36 | print(x1, y1) 37 | ax.add_patch(patches.Rectangle((x1, y1), rect_w, rect_h, 38 | fill=False, 39 | edgecolor=cmap(float(pred_classes[i]) / n_class), 40 | linestyle='dashdot', 41 | alpha=confidence_alpha[i])) 42 | 43 | for i in range(n_gt): 44 | x1 = gt_bb[i, 0] * w 45 | y1 = gt_bb[i, 1] * h 46 | x2 = gt_bb[i, 2] * w 47 | y2 = gt_bb[i, 3] * h 48 | rect_w = x2 - x1 49 | rect_h = y2 - y1 50 | ax.add_patch(patches.Rectangle((x1, y1), rect_w, rect_h, 51 | fill=False, 52 | edgecolor=cmap(float(gt_classes[i]) / n_class))) 53 | 54 | legend_handles = [] 55 | for i in range(n_class): 56 | legend_handles.append(patches.Patch(color=cmap(float(i) / n_class), label="class : {}".format(i))) 57 | ax.legend(handles=legend_handles) 58 | plt.show() 59 | 60 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from setuptools import find_packages 3 | 4 | 5 | setup( 6 | name='mean_average_precision', 7 | version='0.1', 8 | packages=find_packages(), 9 | url='', 10 | license='MIT', 11 | author='Mathieu Garon', 12 | author_email='mathieugaron91@gmail.com', 13 | description='' 14 | ) 15 | -------------------------------------------------------------------------------- /test/test_evaluate.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | 4 | from mean_average_precision.detection_map import DetectionMAP 5 | from mean_average_precision.utils.bbox import jaccard 6 | from mean_average_precision.utils.show_frame import show_frame 7 | 8 | 9 | class TestNumpyFunctions(unittest.TestCase): 10 | def setUp(self): 11 | self.pred = np.array([[0, 0.55, 0.25, 1], 12 | [0, 0.65, 0.25, 1.0], 13 | [0.5, 0.8, 0.6, 1.0], 14 | [0.75, 0, 1.0, 0.45], 15 | [0.75, 0.85, 1.0, 1.0], 16 | [0, 0, 0.09, 0.09]]) 17 | self.cls = np.array([0, 0, 0, 0, 1, 2]) 18 | self.conf = np.array([1, 0.7, 0.5, 1, 0.75, 1]) 19 | self.gt = np.array([[0, 0.6, 0.2, 1.0], 20 | [0.8, 0.9, 1.0, 1.0], 21 | [0.7, 0, 1, 0.5], 22 | [0, 0, 0.1, 0.1]]) 23 | self.gt_cls = np.array([0, 0, 1, 2]) 24 | 25 | #show_frame(self.pred, self.cls, self.conf, self.gt, self.gt_cls) 26 | 27 | def tearDown(self): 28 | pass 29 | 30 | def test_is_iou_thresholded(self): 31 | IoU = DetectionMAP.compute_IoU_mask(self.pred, self.gt, 0.7) 32 | valid_IoU = np.argwhere(IoU) 33 | np.testing.assert_equal(valid_IoU, np.array([[0, 0], [1, 0], [3, 2], [5, 3]])) 34 | 35 | def test_is_FN_incremented_properly(self): 36 | mAP = DetectionMAP(3) 37 | mAP.evaluate(self.pred, self.cls, self.conf, self.gt, self.gt_cls) 38 | 39 | self.assertEqual(mAP.total_accumulators[0][0].FN, 1) 40 | self.assertEqual(mAP.total_accumulators[0][1].FN, 1) 41 | self.assertEqual(mAP.total_accumulators[0][2].FN, 0) 42 | 43 | def test_is_FN_incremented_properly_if_no_prediction(self): 44 | mAP = DetectionMAP(3) 45 | mAP.evaluate(np.array([]), np.array([]), np.array([]), self.gt, self.gt_cls) 46 | self.assertEqual(mAP.total_accumulators[0][0].FN, 2) 47 | self.assertEqual(mAP.total_accumulators[0][1].FN, 1) 48 | self.assertEqual(mAP.total_accumulators[0][2].FN, 1) 49 | 50 | def test_is_TP_incremented_properly(self): 51 | mAP = DetectionMAP(3) 52 | mAP.evaluate(self.pred, self.cls, self.conf, self.gt, self.gt_cls) 53 | 54 | self.assertEqual(mAP.total_accumulators[0][0].TP, 1) 55 | self.assertEqual(mAP.total_accumulators[0][1].TP, 0) 56 | self.assertEqual(mAP.total_accumulators[0][2].TP, 1) 57 | 58 | def test_is_FP_incremented_properly_when_away_from_gt(self): 59 | mAP = DetectionMAP(3) 60 | mAP.evaluate(self.pred, self.cls, self.conf, self.gt, self.gt_cls) 61 | 62 | self.assertEqual(mAP.total_accumulators[0][0].FP, 3) 63 | self.assertEqual(mAP.total_accumulators[0][1].FP, 1) 64 | self.assertEqual(mAP.total_accumulators[0][2].FP, 0) 65 | 66 | --------------------------------------------------------------------------------