diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 8a47a589508..e47128e611b 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -5805,7 +5805,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize): class TestSanitizeBoundingBoxes: - def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10): + def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10, min_area=10): boxes_and_validity = [ ([0, 1, 10, 1], False), # Y1 == Y2 ([0, 1, 0, 20], False), # X1 == X2 @@ -5816,17 +5816,16 @@ def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10): ([-1, 1, 10, 20], False), # any < 0 ([0, 0, -1, 20], False), # any < 0 ([0, 0, -10, -1], False), # any < 0 - ([0, 0, min_size, 10], True), # H < min_size - ([0, 0, 10, min_size], True), # W < min_size - ([0, 0, W, H], True), # TODO: Is that actually OK?? Should it be -1? - ([1, 1, 30, 20], True), - ([0, 0, 10, 10], True), - ([1, 1, 30, 20], True), + ([0, 0, min_size, 10], min_size * 10 >= min_area), # H < min_size + ([0, 0, 10, min_size], min_size * 10 >= min_area), # W < min_size + ([0, 0, W, H], W * H >= min_area), + ([1, 1, 30, 20], 29 * 19 >= min_area), + ([0, 0, 10, 10], 9 * 9 >= min_area), + ([1, 1, 30, 20], 29 * 19 >= min_area), ] random.shuffle(boxes_and_validity) # For test robustness: mix order of wrong and correct cases boxes, expected_valid_mask = zip(*boxes_and_validity) - boxes = tv_tensors.BoundingBoxes( boxes, format=tv_tensors.BoundingBoxFormat.XYXY, @@ -5835,7 +5834,7 @@ def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10): return boxes, expected_valid_mask - @pytest.mark.parametrize("min_size", (1, 10)) + @pytest.mark.parametrize("min_size, min_area", ((1, 1), (10, 1), (10, 101))) @pytest.mark.parametrize( "labels_getter", ( @@ -5848,7 +5847,7 @@ def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10): ), ) @pytest.mark.parametrize("sample_type", (tuple, dict)) - def test_transform(self, min_size, labels_getter, sample_type): + def test_transform(self, min_size, min_area, labels_getter, sample_type): if sample_type is tuple and not isinstance(labels_getter, str): # The "lambda inputs: inputs["labels"]" labels_getter used in this test @@ -5856,7 +5855,7 @@ def test_transform(self, min_size, labels_getter, sample_type): return H, W = 256, 128 - boxes, expected_valid_mask = self._get_boxes_and_valid_mask(H=H, W=W, min_size=min_size) + boxes, expected_valid_mask = self._get_boxes_and_valid_mask(H=H, W=W, min_size=min_size, min_area=min_area) valid_indices = [i for (i, is_valid) in enumerate(expected_valid_mask) if is_valid] labels = torch.arange(boxes.shape[0]) @@ -5880,7 +5879,9 @@ def test_transform(self, min_size, labels_getter, sample_type): img = sample.pop("image") sample = (img, sample) - out = transforms.SanitizeBoundingBoxes(min_size=min_size, labels_getter=labels_getter)(sample) + out = transforms.SanitizeBoundingBoxes(min_size=min_size, min_area=min_area, labels_getter=labels_getter)( + sample + ) if sample_type is tuple: out_image = out[0] @@ -5977,6 +5978,8 @@ def test_errors_transform(self): with pytest.raises(ValueError, match="min_size must be >= 1"): transforms.SanitizeBoundingBoxes(min_size=0) + with pytest.raises(ValueError, match="min_area must be >= 1"): + transforms.SanitizeBoundingBoxes(min_area=0) with pytest.raises(ValueError, match="labels_getter should either be 'default'"): transforms.SanitizeBoundingBoxes(labels_getter=12) diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index 6d62539ccd7..93198f0009d 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -344,7 +344,7 @@ class SanitizeBoundingBoxes(Transform): This transform removes bounding boxes and their associated labels/masks that: - - are below a given ``min_size``: by default this also removes degenerate boxes that have e.g. X2 <= X1. + - are below a given ``min_size`` or ``min_area``: by default this also removes degenerate boxes that have e.g. X2 <= X1. - have any coordinate outside of their corresponding image. You may want to call :class:`~torchvision.transforms.v2.ClampBoundingBoxes` first to avoid undesired removals. @@ -359,7 +359,8 @@ class SanitizeBoundingBoxes(Transform): cases. Args: - min_size (float, optional) The size below which bounding boxes are removed. Default is 1. + min_size (float, optional): The size below which bounding boxes are removed. Default is 1. + min_area (float, optional): The area below which bounding boxes are removed. Default is 1. labels_getter (callable or str or None, optional): indicates how to identify the labels in the input (or anything else that needs to be sanitized along with the bounding boxes). By default, this will try to find a "labels" key in the input (case-insensitive), if @@ -379,6 +380,7 @@ class SanitizeBoundingBoxes(Transform): def __init__( self, min_size: float = 1.0, + min_area: float = 1.0, labels_getter: Union[Callable[[Any], Any], str, None] = "default", ) -> None: super().__init__() @@ -387,6 +389,10 @@ def __init__( raise ValueError(f"min_size must be >= 1, got {min_size}.") self.min_size = min_size + if min_area < 1: + raise ValueError(f"min_area must be >= 1, got {min_area}.") + self.min_area = min_area + self.labels_getter = labels_getter self._labels_getter = _parse_labels_getter(labels_getter) @@ -422,7 +428,9 @@ def forward(self, *inputs: Any) -> Any: format=boxes.format, canvas_size=boxes.canvas_size, min_size=self.min_size, + min_area=self.min_area, ) + params = dict(valid=valid, labels=labels) flat_outputs = [self._transform(inpt, params) for inpt in flat_inputs] diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index 84b686d50f9..daf6c5560db 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -322,12 +322,13 @@ def sanitize_bounding_boxes( format: Optional[tv_tensors.BoundingBoxFormat] = None, canvas_size: Optional[Tuple[int, int]] = None, min_size: float = 1.0, + min_area: float = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor]: """Remove degenerate/invalid bounding boxes and return the corresponding indexing mask. This removes bounding boxes that: - - are below a given ``min_size``: by default this also removes degenerate boxes that have e.g. X2 <= X1. + - are below a given ``min_size`` or ``min_area``: by default this also removes degenerate boxes that have e.g. X2 <= X1. - have any coordinate outside of their corresponding image. You may want to call :func:`~torchvision.transforms.v2.functional.clamp_bounding_boxes` first to avoid undesired removals. @@ -346,6 +347,7 @@ def sanitize_bounding_boxes( (size of the corresponding image/video). Must be left to none if ``bounding_boxes`` is a :class:`~torchvision.tv_tensors.BoundingBoxes` object. min_size (float, optional) The size below which bounding boxes are removed. Default is 1. + min_area (float, optional) The area below which bounding boxes are removed. Default is 1. Returns: out (tuple of Tensors): The subset of valid bounding boxes, and the corresponding indexing mask. @@ -361,7 +363,7 @@ def sanitize_bounding_boxes( if isinstance(format, str): format = tv_tensors.BoundingBoxFormat[format.upper()] valid = _get_sanitize_bounding_boxes_mask( - bounding_boxes, format=format, canvas_size=canvas_size, min_size=min_size + bounding_boxes, format=format, canvas_size=canvas_size, min_size=min_size, min_area=min_area ) bounding_boxes = bounding_boxes[valid] else: @@ -374,7 +376,11 @@ def sanitize_bounding_boxes( "Leave those to None or pass bounding_boxes as a pure tensor." ) valid = _get_sanitize_bounding_boxes_mask( - bounding_boxes, format=bounding_boxes.format, canvas_size=bounding_boxes.canvas_size, min_size=min_size + bounding_boxes, + format=bounding_boxes.format, + canvas_size=bounding_boxes.canvas_size, + min_size=min_size, + min_area=min_area, ) bounding_boxes = tv_tensors.wrap(bounding_boxes[valid], like=bounding_boxes) @@ -386,6 +392,7 @@ def _get_sanitize_bounding_boxes_mask( format: tv_tensors.BoundingBoxFormat, canvas_size: Tuple[int, int], min_size: float = 1.0, + min_area: float = 1.0, ) -> torch.Tensor: bounding_boxes = _convert_bounding_box_format( @@ -394,7 +401,7 @@ def _get_sanitize_bounding_boxes_mask( image_h, image_w = canvas_size ws, hs = bounding_boxes[:, 2] - bounding_boxes[:, 0], bounding_boxes[:, 3] - bounding_boxes[:, 1] - valid = (ws >= min_size) & (hs >= min_size) & (bounding_boxes >= 0).all(dim=-1) + valid = (ws >= min_size) & (hs >= min_size) & (bounding_boxes >= 0).all(dim=-1) & (ws * hs >= min_area) # TODO: Do we really need to check for out of bounds here? All # transforms should be clamping anyway, so this should never happen? image_h, image_w = canvas_size