Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 15 additions & 12 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5805,7 +5805,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):


class TestSanitizeBoundingBoxes:
def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10):
def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10, min_area=10):
boxes_and_validity = [
([0, 1, 10, 1], False), # Y1 == Y2
([0, 1, 0, 20], False), # X1 == X2
Expand All @@ -5816,17 +5816,16 @@ def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10):
([-1, 1, 10, 20], False), # any < 0
([0, 0, -1, 20], False), # any < 0
([0, 0, -10, -1], False), # any < 0
([0, 0, min_size, 10], True), # H < min_size
([0, 0, 10, min_size], True), # W < min_size
([0, 0, W, H], True), # TODO: Is that actually OK?? Should it be -1?
([1, 1, 30, 20], True),
([0, 0, 10, 10], True),
([1, 1, 30, 20], True),
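([0, 0, min_size, 10], min_size * 10 >= min_area), # W == min_size: passes the size check, validity depends on area
([0, 0, 10, min_size], min_size * 10 >= min_area), # H == min_size: passes the size check, validity depends on area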
([0, 0, W, H], W * H >= min_area),
([1, 1, 30, 20], 29 * 19 >= min_area),
([0, 0, 10, 10], 9 * 9 >= min_area),
([1, 1, 30, 20], 29 * 19 >= min_area),
]

random.shuffle(boxes_and_validity) # For test robustness: mix order of wrong and correct cases
boxes, expected_valid_mask = zip(*boxes_and_validity)

boxes = tv_tensors.BoundingBoxes(
boxes,
format=tv_tensors.BoundingBoxFormat.XYXY,
Expand All @@ -5835,7 +5834,7 @@ def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10):

return boxes, expected_valid_mask

@pytest.mark.parametrize("min_size", (1, 10))
@pytest.mark.parametrize("min_size, min_area", ((1, 1), (10, 1), (10, 101)))
@pytest.mark.parametrize(
"labels_getter",
(
Expand All @@ -5848,15 +5847,15 @@ def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10):
),
)
@pytest.mark.parametrize("sample_type", (tuple, dict))
def test_transform(self, min_size, labels_getter, sample_type):
def test_transform(self, min_size, min_area, labels_getter, sample_type):

if sample_type is tuple and not isinstance(labels_getter, str):
# The "lambda inputs: inputs["labels"]" labels_getter used in this test
# doesn't work if the input is a tuple.
return

H, W = 256, 128
boxes, expected_valid_mask = self._get_boxes_and_valid_mask(H=H, W=W, min_size=min_size)
boxes, expected_valid_mask = self._get_boxes_and_valid_mask(H=H, W=W, min_size=min_size, min_area=min_area)
valid_indices = [i for (i, is_valid) in enumerate(expected_valid_mask) if is_valid]

labels = torch.arange(boxes.shape[0])
Expand All @@ -5880,7 +5879,9 @@ def test_transform(self, min_size, labels_getter, sample_type):
img = sample.pop("image")
sample = (img, sample)

out = transforms.SanitizeBoundingBoxes(min_size=min_size, labels_getter=labels_getter)(sample)
out = transforms.SanitizeBoundingBoxes(min_size=min_size, min_area=min_area, labels_getter=labels_getter)(
sample
)

if sample_type is tuple:
out_image = out[0]
Expand Down Expand Up @@ -5977,6 +5978,8 @@ def test_errors_transform(self):

with pytest.raises(ValueError, match="min_size must be >= 1"):
transforms.SanitizeBoundingBoxes(min_size=0)
with pytest.raises(ValueError, match="min_area must be >= 1"):
transforms.SanitizeBoundingBoxes(min_area=0)
with pytest.raises(ValueError, match="labels_getter should either be 'default'"):
transforms.SanitizeBoundingBoxes(labels_getter=12)

Expand Down
12 changes: 10 additions & 2 deletions torchvision/transforms/v2/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ class SanitizeBoundingBoxes(Transform):

This transform removes bounding boxes and their associated labels/masks that:

- are below a given ``min_size``: by default this also removes degenerate boxes that have e.g. X2 <= X1.
- are below a given ``min_size`` or ``min_area``: by default this also removes degenerate boxes that have e.g. X2 <= X1.
- have any coordinate outside of their corresponding image. You may want to
call :class:`~torchvision.transforms.v2.ClampBoundingBoxes` first to avoid undesired removals.

Expand All @@ -359,7 +359,8 @@ class SanitizeBoundingBoxes(Transform):
cases.

Args:
min_size (float, optional) The size below which bounding boxes are removed. Default is 1.
min_size (float, optional): The size below which bounding boxes are removed. Default is 1.
min_area (float, optional): The area below which bounding boxes are removed. Default is 1.
labels_getter (callable or str or None, optional): indicates how to identify the labels in the input
(or anything else that needs to be sanitized along with the bounding boxes).
By default, this will try to find a "labels" key in the input (case-insensitive), if
Expand All @@ -379,6 +380,7 @@ class SanitizeBoundingBoxes(Transform):
def __init__(
self,
min_size: float = 1.0,
min_area: float = 1.0,
labels_getter: Union[Callable[[Any], Any], str, None] = "default",
) -> None:
super().__init__()
Expand All @@ -387,6 +389,10 @@ def __init__(
raise ValueError(f"min_size must be >= 1, got {min_size}.")
self.min_size = min_size

if min_area < 1:
raise ValueError(f"min_area must be >= 1, got {min_area}.")
self.min_area = min_area

self.labels_getter = labels_getter
self._labels_getter = _parse_labels_getter(labels_getter)

Expand Down Expand Up @@ -422,7 +428,9 @@ def forward(self, *inputs: Any) -> Any:
format=boxes.format,
canvas_size=boxes.canvas_size,
min_size=self.min_size,
min_area=self.min_area,
)

params = dict(valid=valid, labels=labels)
flat_outputs = [self._transform(inpt, params) for inpt in flat_inputs]

Expand Down
15 changes: 11 additions & 4 deletions torchvision/transforms/v2/functional/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,12 +322,13 @@ def sanitize_bounding_boxes(
format: Optional[tv_tensors.BoundingBoxFormat] = None,
canvas_size: Optional[Tuple[int, int]] = None,
min_size: float = 1.0,
min_area: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Remove degenerate/invalid bounding boxes and return the corresponding indexing mask.

This removes bounding boxes that:

- are below a given ``min_size``: by default this also removes degenerate boxes that have e.g. X2 <= X1.
- are below a given ``min_size`` or ``min_area``: by default this also removes degenerate boxes that have e.g. X2 <= X1.
- have any coordinate outside of their corresponding image. You may want to
call :func:`~torchvision.transforms.v2.functional.clamp_bounding_boxes` first to avoid undesired removals.

Expand All @@ -346,6 +347,7 @@ def sanitize_bounding_boxes(
(size of the corresponding image/video).
Must be left to none if ``bounding_boxes`` is a :class:`~torchvision.tv_tensors.BoundingBoxes` object.
min_size (float, optional): The size below which bounding boxes are removed. Default is 1.
min_area (float, optional): The area below which bounding boxes are removed. Default is 1.

Returns:
out (tuple of Tensors): The subset of valid bounding boxes, and the corresponding indexing mask.
Expand All @@ -361,7 +363,7 @@ def sanitize_bounding_boxes(
if isinstance(format, str):
format = tv_tensors.BoundingBoxFormat[format.upper()]
valid = _get_sanitize_bounding_boxes_mask(
bounding_boxes, format=format, canvas_size=canvas_size, min_size=min_size
bounding_boxes, format=format, canvas_size=canvas_size, min_size=min_size, min_area=min_area
)
bounding_boxes = bounding_boxes[valid]
else:
Expand All @@ -374,7 +376,11 @@ def sanitize_bounding_boxes(
"Leave those to None or pass bounding_boxes as a pure tensor."
)
valid = _get_sanitize_bounding_boxes_mask(
bounding_boxes, format=bounding_boxes.format, canvas_size=bounding_boxes.canvas_size, min_size=min_size
bounding_boxes,
format=bounding_boxes.format,
canvas_size=bounding_boxes.canvas_size,
min_size=min_size,
min_area=min_area,
)
bounding_boxes = tv_tensors.wrap(bounding_boxes[valid], like=bounding_boxes)

Expand All @@ -386,6 +392,7 @@ def _get_sanitize_bounding_boxes_mask(
format: tv_tensors.BoundingBoxFormat,
canvas_size: Tuple[int, int],
min_size: float = 1.0,
min_area: float = 1.0,
) -> torch.Tensor:

bounding_boxes = _convert_bounding_box_format(
Expand All @@ -394,7 +401,7 @@ def _get_sanitize_bounding_boxes_mask(

image_h, image_w = canvas_size
ws, hs = bounding_boxes[:, 2] - bounding_boxes[:, 0], bounding_boxes[:, 3] - bounding_boxes[:, 1]
valid = (ws >= min_size) & (hs >= min_size) & (bounding_boxes >= 0).all(dim=-1)
valid = (ws >= min_size) & (hs >= min_size) & (bounding_boxes >= 0).all(dim=-1) & (ws * hs >= min_area)
# TODO: Do we really need to check for out of bounds here? All
# transforms should be clamping anyway, so this should never happen?
image_h, image_w = canvas_size
Expand Down