diff --git a/kerascv/layers/anchor_generators.py b/kerascv/layers/anchor_generators.py deleted file mode 100644 index f490e1e91b..0000000000 --- a/kerascv/layers/anchor_generators.py +++ /dev/null @@ -1,141 +0,0 @@ -import tensorflow as tf - - -class AnchorGenerator(tf.keras.layers.Layer): - def __init__( - self, - image_size, - scales, - aspect_ratios, - anchor_stride=None, - anchor_offset=None, - clip_boxes=True, - norm_coord=True, - name=None, - **kwargs - ): - """Constructs a AnchorGenerator.""" - - self.image_size = image_size - self.image_height = image_size[0] - self.image_width = image_size[1] - self.scales = scales - self.aspect_ratios = aspect_ratios - self.anchor_stride = anchor_stride - self.anchor_offset = anchor_offset - self.clip_boxes = clip_boxes - self.norm_coord = norm_coord - super(AnchorGenerator, self).__init__(name=name, **kwargs) - - def call(self, feature_map_size): - feature_map_height = tf.cast(feature_map_size[0], dtype=tf.float32) - feature_map_width = tf.cast(feature_map_size[1], dtype=tf.float32) - image_height = tf.cast(self.image_height, dtype=tf.float32) - image_width = tf.cast(self.image_width, dtype=tf.float32) - - min_image_size = tf.minimum(image_width, image_height) - - if self.anchor_stride is None: - anchor_stride_height = tf.cast( - min_image_size / feature_map_height, dtype=tf.float32 - ) - anchor_stride_width = tf.cast( - min_image_size / feature_map_width, dtype=tf.float32 - ) - else: - anchor_stride_height = tf.cast(self.anchor_stride[0], dtype=tf.float32) - anchor_stride_width = tf.cast(self.anchor_stride[1], dtype=tf.float32) - - if self.anchor_offset is None: - anchor_offset_height = tf.constant(0.5, dtype=tf.float32) - anchor_offset_width = tf.constant(0.5, dtype=tf.float32) - else: - anchor_offset_height = tf.cast(self.anchor_offset[0], dtype=tf.float32) - anchor_offset_width = tf.cast(self.anchor_offset[1], dtype=tf.float32) - - K = len(self.aspect_ratios) - aspect_ratios_sqrt = 
tf.cast(tf.sqrt(self.aspect_ratios), tf.float32) - scales = tf.cast(self.scales, dtype=tf.float32) - # [1, 1, K] - anchor_heights = tf.reshape( - (scales / aspect_ratios_sqrt) * min_image_size, (1, 1, -1) - ) - anchor_widths = tf.reshape( - (scales * aspect_ratios_sqrt) * min_image_size, (1, 1, -1) - ) - - # [W] - cx = (tf.range(feature_map_width) + anchor_offset_width) * anchor_stride_width - # [H] - cy = ( - tf.range(feature_map_height) + anchor_offset_height - ) * anchor_stride_height - # [H, W] - cx_grid, cy_grid = tf.meshgrid(cx, cy) - # [H, W, 1] - cx_grid = tf.expand_dims(cx_grid, axis=-1) - cy_grid = tf.expand_dims(cy_grid, axis=-1) - # [H, W, K] - cx_grid = tf.tile(cx_grid, (1, 1, K)) - cy_grid = tf.tile(cy_grid, (1, 1, K)) - # [H, W, K] - anchor_heights = tf.tile( - anchor_heights, (feature_map_height, feature_map_width, 1) - ) - anchor_widths = tf.tile( - anchor_widths, (feature_map_height, feature_map_width, 1) - ) - - # [H, W, K, 2] - box_centers = tf.stack([cy_grid, cx_grid], axis=3) - # [H * W * K, 2] - box_centers = tf.reshape(box_centers, [-1, 2]) - # [H, W, K, 2] - box_sizes = tf.stack([anchor_heights, anchor_widths], axis=3) - # [H * W * K, 2] - box_sizes = tf.reshape(box_sizes, [-1, 2]) - # y_min, x_min, y_max, x_max - # [H * W * K, 4] - box_tensor = tf.concat( - [box_centers - 0.5 * box_sizes, box_centers + 0.5 * box_sizes], axis=1 - ) - - if self.clip_boxes: - y_min, x_min, y_max, x_max = tf.split( - box_tensor, num_or_size_splits=4, axis=1 - ) - y_min_clipped = tf.maximum(tf.minimum(y_min, self.image_height), 0) - y_max_clipped = tf.maximum(tf.minimum(y_max, self.image_height), 0) - x_min_clipped = tf.maximum(tf.minimum(x_min, self.image_width), 0) - x_max_clipped = tf.maximum(tf.minimum(x_max, self.image_width), 0) - box_tensor = tf.concat( - [y_min_clipped, x_min_clipped, y_max_clipped, x_max_clipped], axis=1 - ) - - if self.norm_coord: - box_tensor = box_tensor / tf.constant( - [ - [ - self.image_height, - self.image_width, - 
self.image_height, - self.image_width, - ] - ], - dtype=box_tensor.dtype, - ) - - return box_tensor - - def get_config(self): - config = { - "image_size": self.image_size, - "scales": self.scales, - "aspect_ratios": self.aspect_ratios, - "anchor_stride": self.anchor_stride, - "anchor_offset": self.anchor_offset, - "clip_boxes": self.clip_boxes, - "norm_coord": self.norm_coord, - } - base_config = super(AnchorGenerator, self).get_config() - return dict(list(base_config.items()) + list(config.items())) diff --git a/kerascv/layers/anchor_generators/__init__.py b/kerascv/layers/anchor_generators/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/kerascv/layers/anchor_generators/anchor_generator.py b/kerascv/layers/anchor_generators/anchor_generator.py new file mode 100644 index 0000000000..abcc71cce1 --- /dev/null +++ b/kerascv/layers/anchor_generators/anchor_generator.py @@ -0,0 +1,177 @@ +# Copyright 2020 The Keras CV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf + + +class AnchorGenerator(tf.keras.layers.Layer): + """Defines a AnchorGenerator that generates anchor boxes for a single feature map. + + # Attributes: + image_size: A list/tuple of 2 ints, the 1st represents the image height, the 2nd image width. + scales: A list/tuple of positive floats (usually less than 1.) as a fraction to shorter side of `image_size`. + It represents the base anchor size (when aspect ratio is 1.). 
For example, if `image_size` is (300, 200), + and `scales=[.1]`, then the base anchor size is 20. + aspect_ratios: a list/tuple of positive floats representing the ratio of anchor width to anchor height. + **Must** have the same length as `scales`. For example, if `image_size=(300, 200)`, `scales=[.1]`, + and `aspect_ratios=[.64]`, the base anchor size is 20, then anchor height is 25 and anchor width is 16. + The anchor aspect ratio is independent of the original aspect ratio of image size. + stride: A list/tuple of 2 ints or floats representing the distance between anchor points. + For example, `stride=(30, 40)` means each anchor is separated by 30 pixels in height, and + 40 pixels in width. Defaults to `None`, where anchor stride would be calculated as + `min(image_height, image_width) / feature_map_height` and + `min(image_height, image_width) / feature_map_width`. + offset: A list/tuple of 2 floats between [0., 1.] representing the center of anchor points relative to + the upper-left border of each feature map cell. Defaults to `None`, which is the center of each + feature map cell when `stride=None`, or center of anchor stride otherwise. + clip_boxes: Boolean that represents whether the anchor coordinates should be clipped to the image size. + Defaults to `True`. + normalize_coordinates: Boolean that represents whether the anchor coordinates should be normalized to [0., 1.] + with respect to the image size. Defaults to `True`.
+ + """ + + def __init__( + self, + image_size, + scales, + aspect_ratios, + stride=None, + offset=None, + clip_boxes=True, + normalize_coordinates=True, + name=None, + **kwargs + ): + """Constructs a AnchorGenerator.""" + + self.image_size = image_size + self.image_height = image_size[0] + self.image_width = image_size[1] + self.scales = scales + self.aspect_ratios = aspect_ratios + self.stride = stride + self.offset = offset + self.clip_boxes = clip_boxes + self.normalize_coordinates = normalize_coordinates + super(AnchorGenerator, self).__init__(name=name, **kwargs) + + def call(self, feature_map_size): + feature_map_height = tf.cast(feature_map_size[0], dtype=tf.float32) + feature_map_width = tf.cast(feature_map_size[1], dtype=tf.float32) + image_height = tf.cast(self.image_height, dtype=tf.float32) + image_width = tf.cast(self.image_width, dtype=tf.float32) + + min_image_size = tf.minimum(image_width, image_height) + + if self.stride is None: + stride_height = tf.cast( + min_image_size / feature_map_height, dtype=tf.float32 + ) + stride_width = tf.cast(min_image_size / feature_map_width, dtype=tf.float32) + else: + stride_height = tf.cast(self.stride[0], dtype=tf.float32) + stride_width = tf.cast(self.stride[1], dtype=tf.float32) + + if self.offset is None: + offset_height = tf.constant(0.5, dtype=tf.float32) + offset_width = tf.constant(0.5, dtype=tf.float32) + else: + offset_height = tf.cast(self.offset[0], dtype=tf.float32) + offset_width = tf.cast(self.offset[1], dtype=tf.float32) + + len_k = len(self.aspect_ratios) + aspect_ratios_sqrt = tf.cast(tf.sqrt(self.aspect_ratios), tf.float32) + scales = tf.cast(self.scales, dtype=tf.float32) + # [1, 1, K] + anchor_heights = tf.reshape( + (scales / aspect_ratios_sqrt) * min_image_size, (1, 1, -1) + ) + anchor_widths = tf.reshape( + (scales * aspect_ratios_sqrt) * min_image_size, (1, 1, -1) + ) + + # [W] + cx = (tf.range(feature_map_width) + offset_width) * stride_width + # [H] + cy = (tf.range(feature_map_height) 
+ offset_height) * stride_height + # [H, W] + cx_grid, cy_grid = tf.meshgrid(cx, cy) + # [H, W, 1] + cx_grid = tf.expand_dims(cx_grid, axis=-1) + cy_grid = tf.expand_dims(cy_grid, axis=-1) + # [H, W, K] + cx_grid = tf.tile(cx_grid, (1, 1, len_k)) + cy_grid = tf.tile(cy_grid, (1, 1, len_k)) + # [H, W, K] + anchor_heights = tf.tile( + anchor_heights, (feature_map_height, feature_map_width, 1) + ) + anchor_widths = tf.tile( + anchor_widths, (feature_map_height, feature_map_width, 1) + ) + + # [H, W, K, 2] + box_centers = tf.stack([cy_grid, cx_grid], axis=3) + # [H * W * K, 2] + box_centers = tf.reshape(box_centers, [-1, 2]) + # [H, W, K, 2] + box_sizes = tf.stack([anchor_heights, anchor_widths], axis=3) + # [H * W * K, 2] + box_sizes = tf.reshape(box_sizes, [-1, 2]) + # y_min, x_min, y_max, x_max + # [H * W * K, 4] + box_tensor = tf.concat( + [box_centers - 0.5 * box_sizes, box_centers + 0.5 * box_sizes], axis=1 + ) + + if self.clip_boxes: + y_min, x_min, y_max, x_max = tf.split( + box_tensor, num_or_size_splits=4, axis=1 + ) + y_min_clipped = tf.maximum(tf.minimum(y_min, self.image_height), 0) + y_max_clipped = tf.maximum(tf.minimum(y_max, self.image_height), 0) + x_min_clipped = tf.maximum(tf.minimum(x_min, self.image_width), 0) + x_max_clipped = tf.maximum(tf.minimum(x_max, self.image_width), 0) + box_tensor = tf.concat( + [y_min_clipped, x_min_clipped, y_max_clipped, x_max_clipped], axis=1 + ) + + if self.normalize_coordinates: + box_tensor = box_tensor / tf.constant( + [ + [ + self.image_height, + self.image_width, + self.image_height, + self.image_width, + ] + ], + dtype=box_tensor.dtype, + ) + + return box_tensor + + def get_config(self): + config = { + "image_size": self.image_size, + "scales": self.scales, + "aspect_ratios": self.aspect_ratios, + "stride": self.stride, + "offset": self.offset, + "clip_boxes": self.clip_boxes, + "normalize_coordinates": self.normalize_coordinates, + } + base_config = super(AnchorGenerator, self).get_config() + return 
dict(list(base_config.items()) + list(config.items())) diff --git a/kerascv/layers/anchor_generators/multi_scale_anchor_generator.py b/kerascv/layers/anchor_generators/multi_scale_anchor_generator.py new file mode 100644 index 0000000000..81be946de2 --- /dev/null +++ b/kerascv/layers/anchor_generators/multi_scale_anchor_generator.py @@ -0,0 +1,114 @@ +# Copyright 2020 The Keras CV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf +from .anchor_generator import AnchorGenerator + + +class MultiScaleAnchorGenerator(tf.keras.layers.Layer): + """Defines a MultiScaleAnchorGenerator that generates anchor boxes for multiple feature maps. + + # Attributes: + image_size: A list/tuple of 2 ints, the 1st represents the image height, the 2nd image width. + scales: A list/tuple of list/tuple of positive floats (usually less than 1.) as a fraction to shorter + side of `image_size`. It represents the base anchor size (when aspect ratio is 1.). + For example, if `image_size=(300, 200)`, and `scales=[[.1]]`, then the base anchor size is 20. + If `image_size=(300, 200)` and `scales=[[.1], [.2]]`, then the base anchor sizes are 20 and 40. + aspect_ratios: a list/tuple of list/tuple of positive floats representing the ratio of anchor width + to anchor height. **Must** have the same length as `scales`. 
+ For example, if `image_size=(300, 200)`, `scales=[[.1]]`, and `aspect_ratios=[[.64]]`, the base anchor + size is 20, then anchor height is 25 and anchor width is 16. If `image_size=(300, 200)`, + `scales=[[.1], [.2]]`, and `aspect_ratios=[[.64], [1.]]`, the base anchor sizes are 20 and 40, then + the anchor heights are 25 and 40, the anchor widths are 16 and 40. + The anchor aspect ratio is independent of the original aspect ratio of image size. + strides: A list/tuple of list/tuple of 2 ints or floats representing the distance between anchor + points. For example, `strides=[(30, 40)]` means each anchor is separated by 30 pixels in height, + and 40 pixels in width. Defaults to `None`, where anchor stride would be calculated as + `min(image_height, image_width) / feature_map_height` and + `min(image_height, image_width) / feature_map_width` for each feature map. + offsets: A list/tuple of list/tuple of 2 floats between [0., 1.] representing the center of anchor + points relative to the upper-left border of each feature map cell. Defaults to `None`, which is the + center of each feature map cell when `strides=None`, or center of each anchor stride otherwise. + clip_boxes: Boolean that represents whether the anchor coordinates should be clipped to the image size. + Defaults to `True`. + normalize_coordinates: Boolean that represents whether the anchor coordinates should be normalized to [0., 1.] + with respect to the image size. Defaults to `True`.
+ + """ + + def __init__( + self, + image_size, + scales, + aspect_ratios, + strides=None, + offsets=None, + clip_boxes=True, + normalize_coordinates=True, + name=None, + **kwargs + ): + self.image_size = image_size + self.image_height = image_size[0] + self.image_width = image_size[1] + self.scales = scales + self.aspect_ratios = aspect_ratios + if strides is None: + strides = [None] * len(scales) + if offsets is None: + offsets = [None] * len(scales) + self.strides = strides + self.offsets = offsets + self.clip_boxes = clip_boxes + self.normalize_coordinates = normalize_coordinates + self.anchor_generators = [] + for (i, (scale_list, aspect_ratio_list, stride, offset)) in enumerate( + zip(scales, aspect_ratios, strides, offsets) + ): + self.anchor_generators.append( + AnchorGenerator( + image_size, + scales=scale_list, + aspect_ratios=aspect_ratio_list, + stride=stride, + offset=offset, + clip_boxes=clip_boxes, + normalize_coordinates=normalize_coordinates, + name="anchor_generator_" + str(i), + ) + ) + super(MultiScaleAnchorGenerator, self).__init__(name=name, **kwargs) + + def call(self, feature_map_sizes): + result = [] + for feature_map_size, anchor_generator in zip( + feature_map_sizes, self.anchor_generators + ): + anchors = anchor_generator(feature_map_size) + anchors = tf.reshape(anchors, (-1, 4)) + result.append(anchors) + return tf.concat(result, axis=0) + + def get_config(self): + config = { + "image_size": self.image_size, + "scales": self.scales, + "aspect_ratios": self.aspect_ratios, + "strides": self.strides, + "offsets": self.offsets, + "clip_boxes": self.clip_boxes, + "normalize_coordinates": self.normalize_coordinates, + } + base_config = super(MultiScaleAnchorGenerator, self).get_config() + return dict(list(base_config.items()) + list(config.items())) diff --git a/tests/kerascv/layers/anchor_generators/__init__.py b/tests/kerascv/layers/anchor_generators/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git 
a/tests/kerascv/layers/anchor_generators/anchor_generator_test.py b/tests/kerascv/layers/anchor_generators/anchor_generator_test.py new file mode 100644 index 0000000000..cc94a7ceaf --- /dev/null +++ b/tests/kerascv/layers/anchor_generators/anchor_generator_test.py @@ -0,0 +1,238 @@ +# Copyright 2020 The Keras CV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from kerascv.layers.anchor_generators import anchor_generator + + +def test_single_scale_absolute_coordinate(): + anchor_gen = anchor_generator.AnchorGenerator( + image_size=(300, 300), + scales=[0.2], + aspect_ratios=[1.0], + clip_boxes=False, + normalize_coordinates=False, + ) + anchor_out = anchor_gen((2, 2)) + expected_out = np.asarray( + [ + [45, 45, 105, 105], + [45, 195, 105, 255], + [195, 45, 255, 105], + [195, 195, 255, 255], + ] + ).astype(np.float32) + np.testing.assert_allclose(expected_out, anchor_out) + + +def test_single_scale_non_square_image(): + anchor_gen = anchor_generator.AnchorGenerator( + image_size=(300, 200), + scales=[0.2], + aspect_ratios=[1.0], + clip_boxes=False, + normalize_coordinates=False, + ) + anchor_out = anchor_gen((2, 2)) + expected_out = np.asarray( + [[30, 30, 70, 70], [30, 130, 70, 170], [130, 30, 170, 70], [130, 130, 170, 170]] + ).astype(np.float32) + np.testing.assert_allclose(expected_out, anchor_out) + + +def test_single_scale_normalized_coordinate(): + anchor_gen = anchor_generator.AnchorGenerator( + image_size=(300, 300), + 
scales=[0.2], + aspect_ratios=[1.0], + clip_boxes=False, + normalize_coordinates=True, + ) + anchor_out = anchor_gen((2, 2)) + expected_out = np.asarray( + [ + [0.15, 0.15, 0.35, 0.35], + [0.15, 0.65, 0.35, 0.85], + [0.65, 0.15, 0.85, 0.35], + [0.65, 0.65, 0.85, 0.85], + ] + ).astype(np.float32) + np.testing.assert_allclose(expected_out, anchor_out) + + +def test_single_scale_customized_stride(): + anchor_gen = anchor_generator.AnchorGenerator( + image_size=(300, 300), + scales=[0.2], + aspect_ratios=[1.0], + stride=[100, 100], + clip_boxes=False, + normalize_coordinates=False, + ) + anchor_out = anchor_gen((2, 2)) + # The center of absolute anchor points would be [50, 50], [50, 150], [150, 50] and [150, 150] + expected_out = np.asarray( + [[20, 20, 80, 80], [20, 120, 80, 180], [120, 20, 180, 80], [120, 120, 180, 180]] + ).astype(np.float32) + np.testing.assert_allclose(expected_out, anchor_out) + + +def test_single_scale_customized_offset(): + anchor_gen = anchor_generator.AnchorGenerator( + image_size=(300, 300), + scales=[0.2], + aspect_ratios=[1.0], + offset=[0.3, 0.3], + clip_boxes=False, + normalize_coordinates=False, + ) + anchor_out = anchor_gen((2, 2)) + # The first center of absolute anchor points would be 300 / 2 * 0.3 = 45, the second would be 45 + 150 = 195 + expected_out = np.asarray( + [[15, 15, 75, 75], [15, 165, 75, 225], [165, 15, 225, 75], [165, 165, 225, 225]] + ).astype(np.float32) + np.testing.assert_allclose(expected_out, anchor_out) + + +def test_over_scale_absolute_coordinate_no_clip(): + anchor_gen = anchor_generator.AnchorGenerator( + image_size=(300, 300), + scales=[0.7], + aspect_ratios=[1.0], + clip_boxes=False, + normalize_coordinates=False, + ) + anchor_out = anchor_gen((2, 2)) + expected_out = np.asarray( + [ + [-30, -30, 180, 180], + [-30, 120, 180, 330], + [120, -30, 330, 180], + [120, 120, 330, 330], + ] + ).astype(np.float32) + np.testing.assert_allclose(expected_out, anchor_out) + + +def 
test_over_scale_absolute_coordinate_clip(): + anchor_gen = anchor_generator.AnchorGenerator( + image_size=(300, 300), + scales=[0.7], + aspect_ratios=[1.0], + clip_boxes=True, + normalize_coordinates=False, + ) + anchor_out = anchor_gen((2, 2)) + expected_out = np.asarray( + [[0, 0, 180, 180], [0, 120, 180, 300], [120, 0, 300, 180], [120, 120, 300, 300]] + ).astype(np.float32) + np.testing.assert_allclose(expected_out, anchor_out) + + +def test_over_scale_normalized_coordinate_no_clip(): + anchor_gen = anchor_generator.AnchorGenerator( + image_size=(300, 300), + scales=[0.7], + aspect_ratios=[1.0], + clip_boxes=False, + normalize_coordinates=True, + ) + anchor_out = anchor_gen((2, 2)) + expected_out = np.asarray( + [ + [-0.1, -0.1, 0.6, 0.6], + [-0.1, 0.4, 0.6, 1.1], + [0.4, -0.1, 1.1, 0.6], + [0.4, 0.4, 1.1, 1.1], + ] + ).astype(np.float32) + np.testing.assert_allclose(expected_out, anchor_out) + + +def test_over_scale_normalized_coordinate_clip(): + anchor_gen = anchor_generator.AnchorGenerator( + image_size=(300, 300), + scales=[0.7], + aspect_ratios=[1.0], + clip_boxes=True, + normalize_coordinates=True, + ) + anchor_out = anchor_gen((2, 2)) + expected_out = np.asarray( + [ + [0.0, 0.0, 0.6, 0.6], + [0.0, 0.4, 0.6, 1.0], + [0.4, 0.0, 1.0, 0.6], + [0.4, 0.4, 1.0, 1.0], + ] + ).astype(np.float32) + np.testing.assert_allclose(expected_out, anchor_out) + + +def test_multi_aspect_ratios(): + anchor_gen = anchor_generator.AnchorGenerator( + image_size=(300, 300), + scales=[0.2, 0.2], + aspect_ratios=[0.64, 1.0], + clip_boxes=False, + normalize_coordinates=False, + ) + anchor_out = anchor_gen((2, 2)) + # height is 300 * 0.2 / 0.8 = 75; width is 300 * 0.2 * 0.8 = 48 + expected_out = np.asarray( + [ + [37.5, 51.0, 112.5, 99.0], + [45.0, 45.0, 105.0, 105.0], + [37.5, 201.0, 112.5, 249.0], + [45.0, 195.0, 105.0, 255.0], + [187.5, 51.0, 262.5, 99.0], + [195.0, 45.0, 255.0, 105.0], + [187.5, 201.0, 262.5, 249.0], + [195.0, 195.0, 255.0, 255.0], + ] + ).astype(np.float32) + 
np.testing.assert_allclose(expected_out, anchor_out) + + +def test_multi_scales(): + anchor_gen = anchor_generator.AnchorGenerator( + image_size=(300, 300), + scales=[0.2, 0.5], + aspect_ratios=[1.0, 1.0], + clip_boxes=False, + normalize_coordinates=False, + ) + anchor_out = anchor_gen((2, 2)) + expected_out = np.asarray( + [ + [45.0, 45.0, 105.0, 105.0], + [0.0, 0.0, 150.0, 150.0], + [45.0, 195.0, 105.0, 255.0], + [0.0, 150.0, 150.0, 300.0], + [195.0, 45.0, 255.0, 105.0], + [150.0, 0.0, 300.0, 150.0], + [195.0, 195.0, 255.0, 255.0], + [150.0, 150.0, 300.0, 300.0], + ] + ).astype(np.float32) + np.testing.assert_allclose(expected_out, anchor_out) + + +def test_config_with_custom_name(): + layer = anchor_generator.AnchorGenerator( + (300, 300), [1.0], [1.0], name="anchor_generator" + ) + config = layer.get_config() + layer_1 = anchor_generator.AnchorGenerator.from_config(config) + np.testing.assert_equal(layer_1.name, layer.name) diff --git a/tests/kerascv/layers/anchor_generators/multi_scale_anchor_generator_test.py b/tests/kerascv/layers/anchor_generators/multi_scale_anchor_generator_test.py new file mode 100644 index 0000000000..1f7ea0b979 --- /dev/null +++ b/tests/kerascv/layers/anchor_generators/multi_scale_anchor_generator_test.py @@ -0,0 +1,224 @@ +# Copyright 2020 The Keras CV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +from kerascv.layers.anchor_generators import multi_scale_anchor_generator + + +def test_single_feature_map_absolute_coordinate(): + anchor_gen = multi_scale_anchor_generator.MultiScaleAnchorGenerator( + image_size=(300, 300), + scales=[[0.2]], + aspect_ratios=[[1.0]], + clip_boxes=False, + normalize_coordinates=False, + ) + anchor_out = anchor_gen([(2, 2)]) + expected_out = np.asarray( + [ + [45, 45, 105, 105], + [45, 195, 105, 255], + [195, 45, 255, 105], + [195, 195, 255, 255], + ] + ).astype(np.float32) + np.testing.assert_allclose(expected_out, anchor_out) + + +def test_single_feature_map_multi_aspect_ratios(): + anchor_gen = multi_scale_anchor_generator.MultiScaleAnchorGenerator( + image_size=(300, 300), + scales=[[0.2, 0.2]], + aspect_ratios=[[0.64, 1.0]], + clip_boxes=False, + normalize_coordinates=False, + ) + anchor_out = anchor_gen([(2, 2)]) + # height is 300 * 0.2 / 0.8 = 75; width is 300 * 0.2 * 0.8 = 48 + expected_out = np.asarray( + [ + [37.5, 51.0, 112.5, 99.0], + [45.0, 45.0, 105.0, 105.0], + [37.5, 201.0, 112.5, 249.0], + [45.0, 195.0, 105.0, 255.0], + [187.5, 51.0, 262.5, 99.0], + [195.0, 45.0, 255.0, 105.0], + [187.5, 201.0, 262.5, 249.0], + [195.0, 195.0, 255.0, 255.0], + ] + ).astype(np.float32) + np.testing.assert_allclose(expected_out, anchor_out) + + +def test_multi_feature_maps_absolute_coordinate(): + anchor_gen = multi_scale_anchor_generator.MultiScaleAnchorGenerator( + image_size=(300, 300), + scales=[[0.1], [0.2]], + aspect_ratios=[[1.0], [1.0]], + clip_boxes=False, + normalize_coordinates=False, + ) + anchor_out = anchor_gen([(3, 3), (2, 2)]) + # The first height and width is 30, the second height and width is 60. 
+ expected_out = np.asarray( + [ + [35.0, 35.0, 65.0, 65.0], + [35.0, 135.0, 65.0, 165.0], + [35.0, 235.0, 65.0, 265.0], + [135.0, 35.0, 165.0, 65.0], + [135.0, 135.0, 165.0, 165.0], + [135.0, 235.0, 165.0, 265.0], + [235.0, 35.0, 265.0, 65.0], + [235.0, 135.0, 265.0, 165.0], + [235.0, 235.0, 265.0, 265.0], + [45.0, 45.0, 105.0, 105.0], + [45.0, 195.0, 105.0, 255.0], + [195.0, 45.0, 255.0, 105.0], + [195.0, 195.0, 255.0, 255.0], + ] + ).astype(np.float32) + np.testing.assert_allclose(expected_out, anchor_out) + + +def test_multi_feature_maps_customized_stride(): + anchor_gen = multi_scale_anchor_generator.MultiScaleAnchorGenerator( + image_size=(300, 300), + scales=[[0.1], [0.2]], + aspect_ratios=[[1.0], [1.0]], + strides=[[120, 120], [160, 160]], + clip_boxes=False, + normalize_coordinates=False, + ) + anchor_out = anchor_gen([(3, 3), (2, 2)]) + # The first center of anchor point for the first feature map is 120 * 0.5 = 60, then 180 + # The first center of anchor point for the second feature map is 160 * 0.5 = 80, then 240 + expected_out = np.asarray( + [ + [45.0, 45.0, 75.0, 75.0], + [45.0, 165.0, 75.0, 195.0], + [45.0, 285.0, 75.0, 315.0], + [165.0, 45.0, 195.0, 75.0], + [165.0, 165.0, 195.0, 195.0], + [165.0, 285.0, 195.0, 315.0], + [285.0, 45.0, 315.0, 75.0], + [285.0, 165.0, 315.0, 195.0], + [285.0, 285.0, 315.0, 315.0], + [50.0, 50.0, 110.0, 110.0], + [50.0, 210.0, 110.0, 270.0], + [210.0, 50.0, 270.0, 110.0], + [210.0, 210.0, 270.0, 270.0], + ] + ).astype(np.float32) + np.testing.assert_allclose(expected_out, anchor_out) + + +def test_multi_feature_maps_customized_offset(): + anchor_gen = multi_scale_anchor_generator.MultiScaleAnchorGenerator( + image_size=(300, 300), + scales=[[0.1], [0.2]], + aspect_ratios=[[1.0], [1.0]], + offsets=[[0.2, 0.2], [0.3, 0.3]], + clip_boxes=False, + normalize_coordinates=False, + ) + anchor_out = anchor_gen([(3, 3), (2, 2)]) + # The first center of anchor point for the first feature map is 100 * 0.2 = 20, then 120 + # The 
first center of anchor point for the second feature map is 150 * 0.3 = 45, then 195 + expected_out = np.asarray( + [ + [5.0, 5.0, 35.0, 35.0], + [5.0, 105.0, 35.0, 135.0], + [5.0, 205.0, 35.0, 235.0], + [105.0, 5.0, 135.0, 35.0], + [105.0, 105.0, 135.0, 135.0], + [105.0, 205.0, 135.0, 235.0], + [205.0, 5.0, 235.0, 35.0], + [205.0, 105.0, 235.0, 135.0], + [205.0, 205.0, 235.0, 235.0], + [15, 15, 75, 75], + [15, 165, 75, 225], + [165, 15, 225, 75], + [165, 165, 225, 225], + ] + ).astype(np.float32) + np.testing.assert_allclose(expected_out, anchor_out) + + +def test_multi_feature_maps_over_scale_absolute_coordinate_no_clip(): + anchor_gen = multi_scale_anchor_generator.MultiScaleAnchorGenerator( + image_size=(300, 300), + scales=[[0.1], [0.7]], + aspect_ratios=[[1.0], [1.0]], + clip_boxes=False, + normalize_coordinates=False, + ) + anchor_out = anchor_gen([(3, 3), (2, 2)]) + # The first height and width is 30, the second height and width is 210. + expected_out = np.asarray( + [ + [35.0, 35.0, 65.0, 65.0], + [35.0, 135.0, 65.0, 165.0], + [35.0, 235.0, 65.0, 265.0], + [135.0, 35.0, 165.0, 65.0], + [135.0, 135.0, 165.0, 165.0], + [135.0, 235.0, 165.0, 265.0], + [235.0, 35.0, 265.0, 65.0], + [235.0, 135.0, 265.0, 165.0], + [235.0, 235.0, 265.0, 265.0], + [-30, -30, 180, 180], + [-30, 120, 180, 330], + [120, -30, 330, 180], + [120, 120, 330, 330], + ] + ).astype(np.float32) + np.testing.assert_allclose(expected_out, anchor_out) + + +def test_multi_feature_maps_over_scale_absolute_coordinate_clip(): + anchor_gen = multi_scale_anchor_generator.MultiScaleAnchorGenerator( + image_size=(300, 300), + scales=[[0.1], [0.7]], + aspect_ratios=[[1.0], [1.0]], + clip_boxes=True, + normalize_coordinates=False, + ) + anchor_out = anchor_gen([(3, 3), (2, 2)]) + # The first height and width is 30, the second height and width is 210.
+ expected_out = np.asarray( + [ + [35.0, 35.0, 65.0, 65.0], + [35.0, 135.0, 65.0, 165.0], + [35.0, 235.0, 65.0, 265.0], + [135.0, 35.0, 165.0, 65.0], + [135.0, 135.0, 165.0, 165.0], + [135.0, 235.0, 165.0, 265.0], + [235.0, 35.0, 265.0, 65.0], + [235.0, 135.0, 265.0, 165.0], + [235.0, 235.0, 265.0, 265.0], + [0, 0, 180, 180], + [0, 120, 180, 300], + [120, 0, 300, 180], + [120, 120, 300, 300], + ] + ).astype(np.float32) + np.testing.assert_allclose(expected_out, anchor_out) + + +def test_config_with_custom_name(): + layer = multi_scale_anchor_generator.MultiScaleAnchorGenerator( + (300, 300), [[1.0]], [[1.0]], name="multi_anchor_generator" + ) + config = layer.get_config() + layer_1 = multi_scale_anchor_generator.MultiScaleAnchorGenerator.from_config(config) + np.testing.assert_equal(layer_1.name, layer.name) diff --git a/tests/kerascv/layers/anchor_generators_test.py b/tests/kerascv/layers/anchor_generators_test.py deleted file mode 100644 index 9f9640bf90..0000000000 --- a/tests/kerascv/layers/anchor_generators_test.py +++ /dev/null @@ -1,121 +0,0 @@ -import numpy as np -from kerascv.layers.anchor_generators import AnchorGenerator - - -def test_single_scale_absolute_coordinate(): - anchor_gen = AnchorGenerator( - image_size=(300, 300), scales=[.2], aspect_ratios=[1.], clip_boxes=False, norm_coord=False) - anchor_out = anchor_gen((2, 2)) - expected_out = np.asarray( - [[45, 45, 105, 105], - [45, 195, 105, 255], - [195, 45, 255, 105], - [195, 195, 255, 255]]).astype(np.float32) - np.testing.assert_allclose(anchor_out, expected_out) - - -def test_single_scale_normalized_coordinate(): - anchor_gen = AnchorGenerator( - image_size=(300, 300), scales=[.2], aspect_ratios=[1.], clip_boxes=False, norm_coord=True - ) - anchor_out = anchor_gen((2, 2)) - expected_out = np.asarray( - [[.15, .15, .35, .35], - [.15, .65, .35, .85], - [.65, .15, .85, .35], - [.65, .65, .85, .85]]).astype(np.float32) - np.testing.assert_allclose(anchor_out, expected_out) - - -def 
test_over_scale_absolute_coordinate_no_clip(): - anchor_gen = AnchorGenerator( - image_size=(300, 300), scales=[.7], aspect_ratios=[1.], clip_boxes=False, norm_coord=False - ) - anchor_out = anchor_gen((2, 2)) - expected_out = np.asarray( - [[-30, -30, 180, 180], - [-30, 120, 180, 330], - [120, -30, 330, 180], - [120, 120, 330, 330]]).astype(np.float32) - np.testing.assert_allclose(anchor_out, expected_out) - - -def test_over_scale_absolute_coordinate_clip(): - anchor_gen = AnchorGenerator( - image_size=(300, 300), scales=[.7], aspect_ratios=[1.], clip_boxes=True, norm_coord=False - ) - anchor_out = anchor_gen((2, 2)) - expected_out = np.asarray( - [[0, 0, 180, 180], - [0, 120, 180, 300], - [120, 0, 300, 180], - [120, 120, 300, 300]]).astype(np.float32) - np.testing.assert_allclose(anchor_out, expected_out) - - -def test_over_scale_normalized_coordinate_no_clip(): - anchor_gen = AnchorGenerator( - image_size=(300, 300), scales=[.7], aspect_ratios=[1.], clip_boxes=False, norm_coord=True - ) - anchor_out = anchor_gen((2, 2)) - expected_out = np.asarray( - [[-0.1, -0.1, 0.6, 0.6], - [-0.1, 0.4, 0.6, 1.1], - [0.4, -0.1, 1.1, 0.6], - [0.4, 0.4, 1.1, 1.1]]).astype(np.float32) - np.testing.assert_allclose(anchor_out, expected_out) - - -def test_over_scale_normalized_coordinate_clip(): - anchor_gen = AnchorGenerator( - image_size=(300, 300), scales=[.7], aspect_ratios=[1.], clip_boxes=True, norm_coord=True - ) - anchor_out = anchor_gen((2, 2)) - expected_out = np.asarray( - [[0., 0., 0.6, 0.6], - [0., 0.4, 0.6, 1.0], - [0.4, 0., 1.0, 0.6], - [0.4, 0.4, 1.0, 1.0]]).astype(np.float32) - np.testing.assert_allclose(anchor_out, expected_out) - - -def test_multi_aspect_ratios(): - anchor_gen = AnchorGenerator( - image_size=(300, 300), scales=[.2, .2], aspect_ratios=[.64, 1.], clip_boxes=False, norm_coord=False - ) - anchor_out = anchor_gen((2, 2)) - # height is 300 * 0.2 / 0.8 = 75; width is 300 * 0.2 * 0.8 = 48 - expected_out = np.asarray( - [[37.5, 51., 112.5, 99.], - [45., 
45., 105., 105.], - [37.5, 201., 112.5, 249.], - [45., 195., 105., 255.], - [187.5, 51., 262.5, 99.], - [195., 45., 255., 105.], - [187.5, 201., 262.5, 249.], - [195., 195., 255., 255.]]).astype(np.float32) - np.testing.assert_allclose(anchor_out, expected_out) - - -def test_multi_scales(): - anchor_gen = AnchorGenerator( - image_size=(300, 300), scales=[.2, .5], aspect_ratios=[1., 1.], clip_boxes=False, norm_coord=False - ) - anchor_out = anchor_gen((2, 2)) - expected_out = np.asarray( - [[45., 45., 105., 105.], - [0., 0., 150., 150.], - [45., 195., 105., 255.], - [0., 150., 150., 300.], - [195., 45., 255., 105.], - [150., 0., 300., 150.], - [195., 195., 255., 255.], - [150., 150., 300., 300.]]).astype(np.float32) - np.testing.assert_allclose(anchor_out, expected_out) - - -def test_config_with_custom_name(): - layer = AnchorGenerator((300, 300), [1.], [1.], name='anchor_generator') - config = layer.get_config() - layer_1 = AnchorGenerator.from_config(config) - np.testing.assert_equal(layer_1.name, layer.name)