diff --git a/README.md b/README.md index 00debfaf8..f97c7035c 100644 --- a/README.md +++ b/README.md @@ -50,18 +50,18 @@ retinanet-train coco /path/to/MS/COCO The pretrained MS COCO model can be downloaded [here](https://github.com/fizyr/keras-retinanet/releases/download/0.1/resnet50_coco_best_v1.2.2.h5). Results using the `cocoapi` are shown below (note: according to the paper, this configuration should achieve a mAP of 0.343). ``` - Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.325 - Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.513 - Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.342 - Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.149 - Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.354 + Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.345 + Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.533 + Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.368 + Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.189 + Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.380 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.465 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.288 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.437 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.464 - Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.263 - Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.510 - Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.623 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.301 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.482 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.529 + Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.364 + Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.565 + Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.666 ``` For training on [OID](https://github.com/openimages/dataset), run: @@ -118,7 +118,7 @@ from keras_retinanet.models.resnet import custom_objects model = keras.models.load_model('/path/to/model.h5', custom_objects=custom_objects) ``` -Execution time on NVIDIA Pascal Titan X is roughly 75msec for an image of shape `1000x600x3`. +Execution time on NVIDIA Pascal Titan X is roughly 75msec for an image of shape `1000x800x3`. ## CSV datasets The `CSVGenerator` provides an easy way to define your own datasets. diff --git a/keras_retinanet/backend/tensorflow_backend.py b/keras_retinanet/backend/tensorflow_backend.py index d2bd1afe4..522f13a90 100644 --- a/keras_retinanet/backend/tensorflow_backend.py +++ b/keras_retinanet/backend/tensorflow_backend.py @@ -18,10 +18,6 @@ import keras -def top_k(*args, **kwargs): - return tensorflow.nn.top_k(*args, **kwargs) - - def resize_images(*args, **kwargs): return tensorflow.image.resize_images(*args, **kwargs) @@ -34,6 +30,10 @@ def range(*args, **kwargs): return tensorflow.range(*args, **kwargs) +def scatter_nd(*args, **kwargs): + return tensorflow.scatter_nd(*args, **kwargs) + + def gather_nd(*args, **kwargs): return tensorflow.gather_nd(*args, **kwargs) diff --git a/keras_retinanet/bin/evaluate_coco.py b/keras_retinanet/bin/evaluate_coco.py index fcc88efa0..374c06eb4 100755 --- a/keras_retinanet/bin/evaluate_coco.py +++ b/keras_retinanet/bin/evaluate_coco.py @@ -75,7 +75,7 @@ def main(args=None): # create a generator for testing data test_generator = CocoGenerator( args.coco_path, - args.set, + args.set ) evaluate_coco(test_generator, model, args.score_threshold) diff --git a/keras_retinanet/layers/_misc.py b/keras_retinanet/layers/_misc.py index 4cb969e20..3fc409f79 100644 --- a/keras_retinanet/layers/_misc.py +++ b/keras_retinanet/layers/_misc.py @@ -76,43 +76,60 @@ def get_config(self): class NonMaximumSuppression(keras.layers.Layer): - def __init__(self, nms_threshold=0.5, top_k=None, max_boxes=300, *args, **kwargs): - self.nms_threshold = nms_threshold - self.top_k = top_k - self.max_boxes = max_boxes + def __init__(self, nms_threshold=0.5, score_threshold=0.05, max_boxes=300, *args, **kwargs): + self.nms_threshold = nms_threshold + self.score_threshold = score_threshold + self.max_boxes = max_boxes super(NonMaximumSuppression, self).__init__(*args, **kwargs) def call(self, inputs, **kwargs): - boxes, classification, detections = inputs - # TODO: support batch size > 1. - boxes = boxes[0] - classification = classification[0] - detections = detections[0] + boxes = inputs[0][0] + classification = inputs[1][0] + other = [i[0] for i in inputs[2:]] # can be any user-specified additional data + indices = backend.range(keras.backend.shape(classification)[0]) + selected_scores = [] + + # perform per class NMS + for c in range(int(classification.shape[1])): + scores = classification[:, c] + + # threshold based on score + score_indices = backend.where(keras.backend.greater(scores, self.score_threshold)) + score_indices = keras.backend.cast(score_indices, 'int32') + boxes_ = backend.gather_nd(boxes, score_indices) + scores = keras.backend.gather(scores, score_indices)[:, 0] + + # perform NMS + nms_indices = backend.non_max_suppression(boxes_, scores, max_output_size=self.max_boxes, iou_threshold=self.nms_threshold) + + # filter set of original indices + selected_indices = keras.backend.gather(score_indices, nms_indices) + + # mask original classification column, setting all suppressed values to 0 + scores = keras.backend.gather(scores, nms_indices) + scores = backend.scatter_nd(selected_indices, scores, keras.backend.shape(classification[:, c])) + scores = keras.backend.expand_dims(scores, axis=1) - scores = keras.backend.max(classification, axis=1) + selected_scores.append(scores) - # selecting best anchors theoretically improves speed at the cost of minor performance - if self.top_k: - scores, indices = backend.top_k(scores, self.top_k, sorted=False) - boxes = keras.backend.gather(boxes, indices) - classification = keras.backend.gather(classification, indices) - detections = keras.backend.gather(detections, indices) + # reconstruct the (suppressed) classification scores + classification = keras.backend.concatenate(selected_scores, axis=1) - indices = backend.non_max_suppression(boxes, scores, max_output_size=self.max_boxes, iou_threshold=self.nms_threshold) + # reconstruct into the expected output + detections = keras.backend.concatenate([boxes, classification] + other, axis=1) - detections = keras.backend.gather(detections, indices) return keras.backend.expand_dims(detections, axis=0) def compute_output_shape(self, input_shape): - return (input_shape[2][0], None, input_shape[2][2]) + return (input_shape[0][0], input_shape[0][1], sum([i[2] for i in input_shape])) def get_config(self): config = super(NonMaximumSuppression, self).get_config() config.update({ - 'nms_threshold' : self.nms_threshold, - 'top_k' : self.top_k, - 'max_boxes' : self.max_boxes, + 'nms_threshold' : self.nms_threshold, + 'score_threshold' : self.score_threshold, + 'max_boxes' : self.max_boxes, }) return config diff --git a/keras_retinanet/models/retinanet.py b/keras_retinanet/models/retinanet.py index 33252dc28..98bc6d194 100644 --- a/keras_retinanet/models/retinanet.py +++ b/keras_retinanet/models/retinanet.py @@ -103,14 +103,15 @@ def default_regression_model(num_anchors, pyramid_feature_size=256, regression_f def __create_pyramid_features(C3, C4, C5, feature_size=256): # upsample C5 to get P5 from the FPN paper - P5 = keras.layers.Conv2D(feature_size, kernel_size=1, strides=1, padding='same', name='P5')(C5) + P5 = keras.layers.Conv2D(feature_size, kernel_size=1, strides=1, padding='same', name='C5_reduced')(C5) P5_upsampled = layers.UpsampleLike(name='P5_upsampled')([P5, C4]) + P5 = keras.layers.Conv2D(feature_size, kernel_size=3, strides=1, padding='same', name='P5')(P5) # add P5 elementwise to C4 P4 = keras.layers.Conv2D(feature_size, kernel_size=1, strides=1, padding='same', name='C4_reduced')(C4) P4 = keras.layers.Add(name='P4_merged')([P5_upsampled, P4]) - P4 = keras.layers.Conv2D(feature_size, kernel_size=3, strides=1, padding='same', name='P4')(P4) P4_upsampled = layers.UpsampleLike(name='P4_upsampled')([P4, C3]) + P4 = keras.layers.Conv2D(feature_size, kernel_size=3, strides=1, padding='same', name='P4')(P4) # add P4 elementwise to C3 P3 = keras.layers.Conv2D(feature_size, kernel_size=1, strides=1, padding='same', name='C3_reduced')(C3) @@ -207,12 +208,13 @@ def retinanet_bbox(inputs, num_classes, nms=True, name='retinanet-bbox', *args, classification = model.outputs[2] # apply predicted regression to anchors - boxes = layers.RegressBoxes(name='boxes')([anchors, regression]) - detections = keras.layers.Concatenate(axis=2)([boxes, classification] + model.outputs[3:]) + boxes = layers.RegressBoxes(name='boxes')([anchors, regression]) # additionally apply non maximum suppression if nms: - detections = layers.NonMaximumSuppression(name='nms')([boxes, classification, detections]) + detections = layers.NonMaximumSuppression(name='nms')([boxes, classification] + model.outputs[3:]) + else: + detections = keras.layers.Concatenate(axis=2)([boxes, classification] + model.outputs[3:]) # construct the model return keras.models.Model(inputs=inputs, outputs=model.outputs[1:] + [detections], name=name) diff --git a/keras_retinanet/preprocessing/generator.py b/keras_retinanet/preprocessing/generator.py index 6916e53d0..ed37ddf25 100644 --- a/keras_retinanet/preprocessing/generator.py +++ b/keras_retinanet/preprocessing/generator.py @@ -40,8 +40,8 @@ def __init__( batch_size=1, group_method='ratio', # one of 'none', 'random', 'ratio' shuffle_groups=True, - image_min_side=600, - image_max_side=1024, + image_min_side=800, + image_max_side=1333, transform_parameters=None, ): self.transform_generator = transform_generator diff --git a/keras_retinanet/utils/anchors.py b/keras_retinanet/utils/anchors.py index b7150bdb8..620e53b2c 100644 --- a/keras_retinanet/utils/anchors.py +++ b/keras_retinanet/utils/anchors.py @@ -171,16 +171,20 @@ def bbox_transform(anchors, gt_boxes, mean=None, std=None): elif not isinstance(std, np.ndarray): raise ValueError('Expected std to be a np.ndarray, list or tuple. Received: {}'.format(type(std))) - anchor_widths = anchors[:, 2] - anchors[:, 0] + 1.0 - anchor_heights = anchors[:, 3] - anchors[:, 1] + 1.0 + anchor_widths = anchors[:, 2] - anchors[:, 0] + anchor_heights = anchors[:, 3] - anchors[:, 1] anchor_ctr_x = anchors[:, 0] + 0.5 * anchor_widths anchor_ctr_y = anchors[:, 1] + 0.5 * anchor_heights - gt_widths = gt_boxes[:, 2] - gt_boxes[:, 0] + 1.0 - gt_heights = gt_boxes[:, 3] - gt_boxes[:, 1] + 1.0 + gt_widths = gt_boxes[:, 2] - gt_boxes[:, 0] + gt_heights = gt_boxes[:, 3] - gt_boxes[:, 1] gt_ctr_x = gt_boxes[:, 0] + 0.5 * gt_widths gt_ctr_y = gt_boxes[:, 1] + 0.5 * gt_heights + # clip widths to 1 + gt_widths = np.maximum(gt_widths, 1) + gt_heights = np.maximum(gt_heights, 1) + targets_dx = (gt_ctr_x - anchor_ctr_x) / anchor_widths targets_dy = (gt_ctr_y - anchor_ctr_y) / anchor_heights targets_dw = np.log(gt_widths / anchor_widths) @@ -204,15 +208,15 @@ def compute_overlap(a, b): ------- overlaps: (N, K) ndarray of overlap between boxes and query_boxes """ - area = (b[:, 2] - b[:, 0] + 1) * (b[:, 3] - b[:, 1] + 1) + area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1]) - iw = np.minimum(np.expand_dims(a[:, 2], axis=1), b[:, 2]) - np.maximum(np.expand_dims(a[:, 0], 1), b[:, 0]) + 1 - ih = np.minimum(np.expand_dims(a[:, 3], axis=1), b[:, 3]) - np.maximum(np.expand_dims(a[:, 1], 1), b[:, 1]) + 1 + iw = np.minimum(np.expand_dims(a[:, 2], axis=1), b[:, 2]) - np.maximum(np.expand_dims(a[:, 0], 1), b[:, 0]) + ih = np.minimum(np.expand_dims(a[:, 3], axis=1), b[:, 3]) - np.maximum(np.expand_dims(a[:, 1], 1), b[:, 1]) iw = np.maximum(iw, 0) ih = np.maximum(ih, 0) - ua = np.expand_dims((a[:, 2] - a[:, 0] + 1) * (a[:, 3] - a[:, 1] + 1), axis=1) + area - iw * ih + ua = np.expand_dims((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), axis=1) + area - iw * ih ua = np.maximum(ua, np.finfo(float).eps) diff --git a/keras_retinanet/utils/coco_eval.py b/keras_retinanet/utils/coco_eval.py index d7a77a6b8..72b3062c1 100644 --- a/keras_retinanet/utils/coco_eval.py +++ b/keras_retinanet/utils/coco_eval.py @@ -28,8 +28,8 @@ def evaluate_coco(generator, model, threshold=0.05): # start collecting results results = [] image_ids = [] - for i in range(generator.size()): - image = generator.load_image(i) + for index in range(generator.size()): + image = generator.load_image(index) image = generator.preprocess_image(image) image, scale = generator.resize_image(image) @@ -50,26 +50,23 @@ def evaluate_coco(generator, model, threshold=0.05): detections[:, :, 3] -= detections[:, :, 1] # compute predicted labels and scores - for detection in detections[0, ...]: - positive_labels = np.where(detection[4:] > threshold)[0] - + for i, j in np.transpose(np.where(detections[0, :, 4:] > threshold)): # append detections for each positively labeled class - for label in positive_labels: - image_result = { - 'image_id' : generator.image_ids[i], - 'category_id' : generator.label_to_coco_label(label), - 'score' : float(detection[4 + label]), - 'bbox' : (detection[:4]).tolist(), - } + image_result = { + 'image_id' : generator.image_ids[index], + 'category_id' : generator.label_to_coco_label(j), + 'score' : float(detections[0, i, 4 + j]), + 'bbox' : (detections[0, i, :4]).tolist(), + } - # append detection to results - results.append(image_result) + # append detection to results + results.append(image_result) # append image to list of processed images - image_ids.append(generator.image_ids[i]) + image_ids.append(generator.image_ids[index]) # print progress - print('{}/{}'.format(i, generator.size()), end='\r') + print('{}/{}'.format(index, generator.size()), end='\r') if not len(results): return diff --git a/keras_retinanet/utils/image.py b/keras_retinanet/utils/image.py index a50490334..847a97ade 100644 --- a/keras_retinanet/utils/image.py +++ b/keras_retinanet/utils/image.py @@ -160,7 +160,7 @@ def apply_transform(matrix, image, params): return output -def resize_image(img, min_side=600, max_side=1024): +def resize_image(img, min_side=800, max_side=1333): (rows, cols, _) = img.shape smallest_side = min(rows, cols) diff --git a/tests/layers/test_misc.py b/tests/layers/test_misc.py index 0f4555473..c0d12ab00 100644 --- a/tests/layers/test_misc.py +++ b/tests/layers/test_misc.py @@ -100,18 +100,19 @@ def test_simple(self): ]], dtype=keras.backend.floatx()) classification = keras.backend.variable(classification) - detections = np.array([[ + other = np.array([[ [1, 2, 3], [4, 5, 6], ]], dtype=keras.backend.floatx()) - detections = keras.backend.variable(detections) + other = keras.backend.variable(other) # compute output - actual = non_maximum_suppression_layer.call([boxes, classification, detections]) + actual = non_maximum_suppression_layer.call([boxes, classification, other]) actual = keras.backend.eval(actual) expected = np.array([[ - [4, 5, 6], + [0, 0, 10, 10, 0, 0, 1, 2, 3], + [0, 0, 10, 10, 0, 1, 4, 5, 6], ]], dtype=keras.backend.floatx()) np.testing.assert_array_equal(actual, expected) @@ -147,7 +148,7 @@ def test_mini_batch(self): ], dtype=keras.backend.floatx()) classification = keras.backend.variable(classification) - detections = np.array([ + other = np.array([ [ [1, 2, 3], [4, 5, 6], @@ -157,18 +158,18 @@ def test_mini_batch(self): [10, 11, 12], ], ], dtype=keras.backend.floatx()) - detections = keras.backend.variable(detections) + other = keras.backend.variable(other) # compute output - actual = non_maximum_suppression_layer.call([boxes, classification, detections]) + actual = non_maximum_suppression_layer.call([boxes, classification, other]) actual = keras.backend.eval(actual) expected = np.array([ [ - [4, 5, 6], + [0, 0, 10, 10, 0, 1, 4, 5, 6], ], [ - [7, 8, 9], + [100, 100, 150, 150, 0, 1, 7, 8, 9], ], ], dtype=keras.backend.floatx())