diff --git a/ML/Pytorch/object_detection/metrics/mean_avg_precision.py b/ML/Pytorch/object_detection/metrics/mean_avg_precision.py index 73d1e65f..aa7efe95 100644 --- a/ML/Pytorch/object_detection/metrics/mean_avg_precision.py +++ b/ML/Pytorch/object_detection/metrics/mean_avg_precision.py @@ -74,6 +74,7 @@ def mean_average_precision( num_gts = len(ground_truth_img) best_iou = 0 + best_gt_idx = None for idx, gt in enumerate(ground_truth_img): iou = intersection_over_union( @@ -86,7 +87,7 @@ def mean_average_precision( best_iou = iou best_gt_idx = idx - if best_iou > iou_threshold: + if best_gt_idx is not None and best_iou > iou_threshold: # only detect ground truth detection once if amount_bboxes[detection[0]][best_gt_idx] == 0: # true positive and add this bounding box to seen