From 5e947285cd84fab66b348d3f9d0aa027ecc73fdc Mon Sep 17 00:00:00 2001
From: tanzhenyu
Date: Tue, 16 Jun 2020 10:40:15 -0700
Subject: [PATCH 1/2] [Object_Detection] Add IOUSimilarity layer.

---
Review notes for this revision:
 * Line structure of the series was restored (it had been collapsed onto a
   handful of physical lines); hunk counts were re-derived accordingly.
 * IOUSimilarity.call now raises ValueError when dense `anchors` is neither
   rank 2 nor rank 3; it used to fall off the end and return None.
 * Boxes are read as [y_min, x_min, y_max, x_max] (see the tf.split calls).
 * Comments in test_iou_basic_normalized_coordinate were copied from the
   absolute-coordinate test; they now state the real areas (0.25), the
   intersection (1/16) and the unions (7/16 and 8/16).
 * Docstring grammar: "a IOUSimilarity" -> "an IOUSimilarity".
 * The `index` blob ids and diffstat graphs are carried over from the
   original series and no longer match the edited contents.

 kerascv/layers/iou_similarity.py            |  65 ++++++++++
 tests/kerascv/layers/iou_similarity_test.py | 123 ++++++++++++++++++++
 2 files changed, 188 insertions(+)
 create mode 100644 kerascv/layers/iou_similarity.py
 create mode 100644 tests/kerascv/layers/iou_similarity_test.py

diff --git a/kerascv/layers/iou_similarity.py b/kerascv/layers/iou_similarity.py
new file mode 100644
index 0000000000..34b93d5d5a
--- /dev/null
+++ b/kerascv/layers/iou_similarity.py
@@ -0,0 +1,65 @@
+import tensorflow as tf
+
+
+class IOUSimilarity(tf.keras.layers.Layer):
+    """Defines an IOUSimilarity that calculates the IOU between ground truth boxes and anchors."""
+
+    def __init__(self, name=None, **kwargs):
+        super(IOUSimilarity, self).__init__(name=name, **kwargs)
+
+    # TODO: support ragged ground_truth_boxes
+    def call(self, ground_truth_boxes, anchors):
+        # ground_truth_box [n_gt_boxes, box_dim] or [batch_size, n_gt_boxes, box_dim]
+        # anchor [n_anchors, box_dim]
+        def iou(ground_truth_box, anchor):
+            # [n_anchors, 1]
+            y_min_anchors, x_min_anchors, y_max_anchors, x_max_anchors = tf.split(
+                anchor, num_or_size_splits=4, axis=-1
+            )
+            # [n_gt_boxes, 1] or [batch_size, n_gt_boxes, 1]
+            y_min_gt, x_min_gt, y_max_gt, x_max_gt = tf.split(
+                ground_truth_box, num_or_size_splits=4, axis=-1
+            )
+            # [n_anchors]
+            anchor_areas = tf.squeeze(
+                (y_max_anchors - y_min_anchors) * (x_max_anchors - x_min_anchors), [1]
+            )
+            # [n_gt_boxes, 1] or [batch_size, n_gt_boxes, 1]
+            gt_areas = (y_max_gt - y_min_gt) * (x_max_gt - x_min_gt)
+
+            # [n_gt_boxes, n_anchors] or [batch_size, n_gt_boxes, n_anchors]
+            max_y_min = tf.maximum(y_min_gt, tf.transpose(y_min_anchors))
+            min_y_max = tf.minimum(y_max_gt, tf.transpose(y_max_anchors))
+            intersect_heights = tf.maximum(
+                tf.constant(0, dtype=ground_truth_box.dtype), (min_y_max - max_y_min)
+            )
+
+            # [n_gt_boxes, n_anchors] or [batch_size, n_gt_boxes, n_anchors]
+            max_x_min = tf.maximum(x_min_gt, tf.transpose(x_min_anchors))
+            min_x_max = tf.minimum(x_max_gt, tf.transpose(x_max_anchors))
+            intersect_widths = tf.maximum(
+                tf.constant(0, dtype=ground_truth_box.dtype), (min_x_max - max_x_min)
+            )
+
+            # [n_gt_boxes, n_anchors] or [batch_size, n_gt_boxes, n_anchors]
+            intersections = intersect_heights * intersect_widths
+
+            # [n_gt_boxes, n_anchors] or [batch_size, n_gt_boxes, n_anchors]
+            unions = gt_areas + anchor_areas - intersections
+
+            return tf.cast(tf.truediv(intersections, unions), tf.float32)
+
+        if anchors.shape.ndims == 2:
+            return iou(ground_truth_boxes, anchors)
+        elif anchors.shape.ndims == 3:
+            return tf.map_fn(
+                lambda x: iou(x[0], x[1]),
+                elems=[ground_truth_boxes, anchors],
+                dtype=tf.float32,
+                parallel_iterations=32,
+                back_prop=False,
+            )
+        else:
+            raise ValueError(
+                "anchors must be rank 2 or 3, got rank {}".format(anchors.shape.ndims)
+            )
diff --git a/tests/kerascv/layers/iou_similarity_test.py b/tests/kerascv/layers/iou_similarity_test.py
new file mode 100644
index 0000000000..8dc1eeff8d
--- /dev/null
+++ b/tests/kerascv/layers/iou_similarity_test.py
@@ -0,0 +1,123 @@
+import numpy as np
+import tensorflow as tf
+from kerascv.layers.iou_similarity import IOUSimilarity
+
+
+def test_iou_basic_absolute_coordinate():
+    # both gt box and two anchors are size 4
+    # the intersection between gt box and first anchor is 1 and union is 7
+    # the intersection between gt box and second anchor is 0 and union is 8
+    gt_boxes = tf.constant([[[0, 2, 2, 4]]])
+    anchors = tf.constant([[1, 1, 3, 3], [-2, 1, 0, 3]])
+    iou_layer = IOUSimilarity()
+    similarity = iou_layer(gt_boxes, anchors)
+    # batch_size = 1, n_gt_boxes = 1, n_anchors = 2
+    expected_out = np.asarray([1 / 7, 0]).astype(np.float32).reshape((1, 1, 2))
+    np.testing.assert_allclose(expected_out, similarity)
+
+
+def test_iou_basic_normalized_coordinate():
+    # both gt box and two anchors are size 0.25
+    # the intersection between gt box and first anchor is 1/16 and union is 7/16
+    # the intersection between gt box and second anchor is 0 and union is 8/16
+    gt_boxes = tf.constant([[[0, 0.5, 0.5, 1.0]]])
+    anchors = tf.constant([[0.25, 0.25, 0.75, 0.75], [-0.5, 0.25, 0, 0.75]])
+    iou_layer = IOUSimilarity()
+    similarity = iou_layer(gt_boxes, anchors)
+    expected_out = np.asarray([1 / 7, 0]).astype(np.float32).reshape((1, 1, 2))
+    np.testing.assert_allclose(expected_out, similarity)
+
+
+def test_iou_multi_gt_multi_anchor_absolute_coordinate():
+    # batch_size = 1, n_gt_boxes = 2
+    # [1, 2, 4]
+    gt_boxes = tf.constant([[[0, 2, 2, 4], [-1, 1, 1, 3]]])
+    # [2, 4]
+    anchors = tf.constant([[1, 1, 3, 3], [-2, 1, 0, 3]])
+    iou_layer = IOUSimilarity()
+    similarity = iou_layer(gt_boxes, anchors)
+    expected_out = (
+        np.asarray([[1 / 7, 0], [0, 1 / 3]]).astype(np.float32).reshape((1, 2, 2))
+    )
+    np.testing.assert_allclose(expected_out, similarity)
+
+
+def test_iou_batched_gt_multi_anchor_absolute_coordinate():
+    # batch_size = 2, n_gt_boxes = 1
+    # [2, 1, 4]
+    gt_boxes = tf.constant([[[0, 2, 2, 4]], [[-1, 1, 1, 3]]])
+    # [2, 4]
+    anchors = tf.constant([[1, 1, 3, 3], [-2, 1, 0, 3]])
+    iou_layer = IOUSimilarity()
+    similarity = iou_layer(gt_boxes, anchors)
+    expected_out = (
+        np.asarray([[1 / 7, 0], [0, 1 / 3]]).astype(np.float32).reshape((2, 1, 2))
+    )
+    np.testing.assert_allclose(expected_out, similarity)
+
+
+def test_iou_batched_gt_batched_anchor_absolute_coordinate():
+    # batch_size = 2, n_gt_boxes = 1
+    # [2, 1, 4]
+    gt_boxes = tf.constant([[[0, 2, 2, 4]], [[-1, 1, 1, 3]]])
+    # [2, 1, 4]
+    anchors = tf.constant([[[1, 1, 3, 3]], [[-2, 1, 0, 3]]])
+    iou_layer = IOUSimilarity()
+    similarity = iou_layer(gt_boxes, anchors)
+    expected_out = np.asarray([[1 / 7], [1 / 3]]).astype(np.float32).reshape((2, 1, 1))
+    np.testing.assert_allclose(expected_out, similarity)
+
+
+def test_iou_multi_gt_multi_anchor_normalized_coordinate():
+    # batch_size = 1, n_gt_boxes = 2
+    # [1, 2, 4]
+    gt_boxes = tf.constant([[[0.0, 0.5, 0.5, 1.0], [-0.25, 0.25, 0.25, 0.75]]])
+    # [2, 4]
+    anchors = tf.constant([[0.25, 0.25, 0.75, 0.75], [-0.5, 0.25, 0.0, 0.75]])
+    iou_layer = IOUSimilarity()
+    similarity = iou_layer(gt_boxes, anchors)
+    expected_out = (
+        np.asarray([[1 / 7, 0], [0, 1 / 3]]).astype(np.float32).reshape((1, 2, 2))
+    )
+    np.testing.assert_allclose(expected_out, similarity)
+
+
+def test_iou_batched_gt_multi_anchor_normalized_coordinate():
+    # batch_size = 2, n_gt_boxes = 1
+    # [2, 1, 4]
+    gt_boxes = tf.constant([[[0.0, 0.5, 0.5, 1.0]], [[-0.25, 0.25, 0.25, 0.75]]])
+    # [2, 4]
+    anchors = tf.constant([[0.25, 0.25, 0.75, 0.75], [-0.5, 0.25, 0.0, 0.75]])
+    iou_layer = IOUSimilarity()
+    similarity = iou_layer(gt_boxes, anchors)
+    expected_out = (
+        np.asarray([[1 / 7, 0], [0, 1 / 3]]).astype(np.float32).reshape((2, 1, 2))
+    )
+    np.testing.assert_allclose(expected_out, similarity)
+
+
+def test_iou_batched_gt_batched_anchor_normalized_coordinate():
+    # batch_size = 2, n_gt_boxes = 1
+    # [2, 1, 4]
+    gt_boxes = tf.constant([[[0.0, 0.5, 0.5, 1.0]], [[-0.25, 0.25, 0.25, 0.75]]])
+    # [2, 1, 4]
+    anchors = tf.constant([[[0.25, 0.25, 0.75, 0.75]], [[-0.5, 0.25, 0.0, 0.75]]])
+    iou_layer = IOUSimilarity()
+    similarity = iou_layer(gt_boxes, anchors)
+    expected_out = np.asarray([[1 / 7], [1 / 3]]).astype(np.float32).reshape((2, 1, 1))
+    np.testing.assert_allclose(expected_out, similarity)
+
+
+def test_iou_large():
+    # [2, 4]
+    gt_boxes = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]])
+    # [3, 4]
+    anchors = tf.constant(
+        [[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0], [0.0, 0.0, 20.0, 20.0]]
+    )
+    iou_layer = IOUSimilarity()
+    similarity = iou_layer(gt_boxes, anchors)
+    expected_out = np.asarray([[2 / 16, 0, 6 / 400], [1 / 16, 0.0, 5 / 400]]).astype(
+        np.float32
+    )
+    np.testing.assert_allclose(expected_out, similarity)

From 7a00ce0429fd0089d42c6fb486cb7e62c7379bb3 Mon Sep 17 00:00:00 2001
From: tanzhenyu
Date: Tue, 16 Jun 2020 12:13:33 -0700
Subject: [PATCH 2/2] Add ragged tensor support.

---
Review notes for this revision:
 * Line structure restored; hunk counts re-derived.
 * The ragged tests used to compare only the flattened `.values`, so a wrong
   row partition would have passed. Each ragged test now also compares
   `row_splits`.
 * Docstring grammar: "a IOUSimilarity" -> "an IOUSimilarity".
 * Not changed here: the dense branch still passes the deprecated `dtype=`
   and `back_prop=` arguments to tf.map_fn, and a pair of zero-area boxes
   still yields 0/0 in `iou`.
 * The `index` blob ids and diffstat graphs are carried over from the
   original series and no longer match the edited contents.

 kerascv/layers/iou_similarity.py            | 28 ++++++++++-
 tests/kerascv/layers/iou_similarity_test.py | 67 +++++++++++++++++++++
 2 files changed, 93 insertions(+), 2 deletions(-)

diff --git a/kerascv/layers/iou_similarity.py b/kerascv/layers/iou_similarity.py
index 34b93d5d5a..7028142ed8 100644
--- a/kerascv/layers/iou_similarity.py
+++ b/kerascv/layers/iou_similarity.py
@@ -2,12 +2,15 @@
 
 
 class IOUSimilarity(tf.keras.layers.Layer):
-    """Defines an IOUSimilarity that calculates the IOU between ground truth boxes and anchors."""
+    """Defines an IOUSimilarity that calculates the IOU between ground truth boxes and anchors.
+
+    Calling the layer with `ground_truth_boxes` and `anchors`, `ground_truth_boxes` can be a batched
+    `tf.Tensor` or `tf.RaggedTensor`, while `anchors` can be a batched or un-batched `tf.Tensor`.
+    """
 
     def __init__(self, name=None, **kwargs):
         super(IOUSimilarity, self).__init__(name=name, **kwargs)
 
-    # TODO: support ragged ground_truth_boxes
     def call(self, ground_truth_boxes, anchors):
         # ground_truth_box [n_gt_boxes, box_dim] or [batch_size, n_gt_boxes, box_dim]
         # anchor [n_anchors, box_dim]
@@ -49,6 +52,27 @@ def iou(ground_truth_box, anchor):
 
             return tf.cast(tf.truediv(intersections, unions), tf.float32)
 
+        if isinstance(ground_truth_boxes, tf.RaggedTensor):
+            if anchors.shape.ndims == 2:
+                return tf.map_fn(
+                    lambda x: iou(x, anchors),
+                    elems=ground_truth_boxes,
+                    parallel_iterations=32,
+                    back_prop=False,
+                    fn_output_signature=tf.RaggedTensorSpec(
+                        dtype=tf.float32, ragged_rank=0
+                    ),
+                )
+            else:
+                return tf.map_fn(
+                    lambda x: iou(x[0], x[1]),
+                    elems=[ground_truth_boxes, anchors],
+                    parallel_iterations=32,
+                    back_prop=False,
+                    fn_output_signature=tf.RaggedTensorSpec(
+                        dtype=tf.float32, ragged_rank=0
+                    ),
+                )
         if anchors.shape.ndims == 2:
             return iou(ground_truth_boxes, anchors)
         elif anchors.shape.ndims == 3:
diff --git a/tests/kerascv/layers/iou_similarity_test.py b/tests/kerascv/layers/iou_similarity_test.py
index 8dc1eeff8d..cd465d0430 100644
--- a/tests/kerascv/layers/iou_similarity_test.py
+++ b/tests/kerascv/layers/iou_similarity_test.py
@@ -121,3 +121,70 @@ def test_iou_large():
         np.float32
     )
     np.testing.assert_allclose(expected_out, similarity)
+
+
+def test_ragged_gt_boxes_multi_anchor_absolute_coordinate():
+    # [2, ragged, 4]
+    gt_boxes = tf.ragged.constant(
+        [[[0, 2, 2, 4]], [[-1, 1, 1, 3], [-1, 1, 2, 3]]], ragged_rank=1
+    )
+    # [2, 4]
+    anchors = tf.constant([[1, 1, 3, 3], [-2, 1, 0, 3]])
+    iou_layer = IOUSimilarity()
+    similarity = iou_layer(gt_boxes, anchors)
+    expected_out = tf.ragged.constant([[[1 / 7, 0.0]], [[0.0, 1 / 3], [1 / 4, 1 / 4]]])
+    np.testing.assert_allclose(expected_out.values.numpy(), similarity.values.numpy())
+    np.testing.assert_array_equal(
+        expected_out.row_splits.numpy(), similarity.row_splits.numpy()
+    )
+
+
+def test_ragged_gt_boxes_multi_anchor_normalized_coordinate():
+    # [2, ragged, 4]
+    gt_boxes = tf.ragged.constant(
+        [[[0.0, 0.5, 0.5, 1.0]], [[-0.25, 0.25, 0.25, 0.75], [-0.25, 0.25, 0.5, 0.75]]],
+        ragged_rank=1,
+    )
+    # [2, 4]
+    anchors = tf.constant([[0.25, 0.25, 0.75, 0.75], [-0.5, 0.25, 0.0, 0.75]])
+    iou_layer = IOUSimilarity()
+    similarity = iou_layer(gt_boxes, anchors)
+    expected_out = tf.ragged.constant([[[1 / 7, 0.0]], [[0.0, 1 / 3], [1 / 4, 1 / 4]]])
+    np.testing.assert_allclose(expected_out.values.numpy(), similarity.values.numpy())
+    np.testing.assert_array_equal(
+        expected_out.row_splits.numpy(), similarity.row_splits.numpy()
+    )
+
+
+def test_ragged_gt_boxes_batched_anchor_normalized_coordinate():
+    # [2, ragged, 4]
+    gt_boxes = tf.ragged.constant(
+        [[[0.0, 0.5, 0.5, 1.0]], [[-0.25, 0.25, 0.25, 0.75], [-0.25, 0.25, 0.5, 0.75]]],
+        ragged_rank=1,
+    )
+    # [2, 1, 4]
+    anchors = tf.constant([[[0.25, 0.25, 0.75, 0.75]], [[-0.5, 0.25, 0.0, 0.75]]])
+    iou_layer = IOUSimilarity()
+    similarity = iou_layer(gt_boxes, anchors)
+    expected_out = tf.ragged.constant([[[1 / 7]], [[1 / 3], [1 / 4]]])
+    np.testing.assert_allclose(expected_out.values.numpy(), similarity.values.numpy())
+    np.testing.assert_array_equal(
+        expected_out.row_splits.numpy(), similarity.row_splits.numpy()
+    )
+
+
+def test_ragged_gt_boxes_empty_anchor():
+    # [2, ragged, 4]
+    gt_boxes = tf.ragged.constant(
+        [[[0.0, 0.5, 0.5, 1.0]], [[-0.25, 0.25, 0.25, 0.75], [-0.25, 0.25, 0.5, 0.75]]],
+        ragged_rank=1,
+    )
+    # [2, 4]
+    anchors = tf.constant([[0.25, 0.25, 0.25, 0.25], [-0.5, 0.25, 0.0, 0.75]])
+    iou_layer = IOUSimilarity()
+    similarity = iou_layer(gt_boxes, anchors)
+    expected_out = tf.ragged.constant([[[0.0, 0.0]], [[0.0, 1 / 3], [0.0, 1 / 4]]])
+    np.testing.assert_allclose(expected_out.values.numpy(), similarity.values.numpy())
+    np.testing.assert_array_equal(
+        expected_out.row_splits.numpy(), similarity.row_splits.numpy()
+    )