diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index 0735eff6575..b24e9a41ff7 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ -95,7 +95,7 @@ def make_bounding_box(*, format, image_size=(32, 32), extra_dims=(), dtype=torch
         cx = torch.randint(1, width - 1, ())
         cy = torch.randint(1, height - 1, ())
         w = randint_with_tensor_bounds(1, torch.minimum(cx, width - cx) + 1)
-        h = randint_with_tensor_bounds(1, torch.minimum(cy, width - cy) + 1)
+        h = randint_with_tensor_bounds(1, torch.minimum(cy, height - cy) + 1)
         parts = (cx, cy, w, h)
     else:
         raise pytest.UsageError()
@@ -413,6 +413,14 @@ def perspective_segmentation_mask():
         )
 
 
+@register_kernel_info_from_sample_inputs_fn
+def center_crop_bounding_box():
+    for bounding_box, output_size in itertools.product(make_bounding_boxes(), [(24, 12), [16, 18], [46, 48], [12]]):
+        yield SampleInput(
+            bounding_box, format=bounding_box.format, output_size=output_size, image_size=bounding_box.image_size
+        )
+
+
 @pytest.mark.parametrize(
     "kernel",
     [
@@ -1273,3 +1281,59 @@ def _compute_expected_mask(mask, pcoeffs_):
     else:
         expected_masks = expected_masks[0]
     torch.testing.assert_close(output_mask, expected_masks)
+
+
+@pytest.mark.parametrize("device", cpu_and_gpu())
+@pytest.mark.parametrize(
+    "output_size",
+    [(18, 18), [18, 15], (16, 19), [12], [46, 48]],
+)
+def test_correctness_center_crop_bounding_box(device, output_size):
+    def _compute_expected_bbox(bbox, output_size_):
+        format_ = bbox.format
+        image_size_ = bbox.image_size
+        bbox = convert_bounding_box_format(bbox, format_, features.BoundingBoxFormat.XYWH)
+
+        if len(output_size_) == 1:
+            output_size_.append(output_size_[-1])
+
+        cy = int(round((image_size_[0] - output_size_[0]) * 0.5))
+        cx = int(round((image_size_[1] - output_size_[1]) * 0.5))
+        out_bbox = [
+            bbox[0].item() - cx,
+            bbox[1].item() - cy,
+            bbox[2].item(),
+            bbox[3].item(),
+        ]
+        out_bbox = features.BoundingBox(
+            out_bbox,
+            format=features.BoundingBoxFormat.XYWH,
+            image_size=output_size_,
+            dtype=bbox.dtype,
+            device=bbox.device,
+        )
+        return convert_bounding_box_format(out_bbox, features.BoundingBoxFormat.XYWH, format_, copy=False)
+
+    for bboxes in make_bounding_boxes(
+        image_sizes=[(32, 32), (24, 33), (32, 25)],
+        extra_dims=((4,),),
+    ):
+        bboxes = bboxes.to(device)
+        bboxes_format = bboxes.format
+        bboxes_image_size = bboxes.image_size
+
+        output_boxes = F.center_crop_bounding_box(bboxes, bboxes_format, output_size, bboxes_image_size)
+
+        if bboxes.ndim < 2:
+            bboxes = [bboxes]
+
+        expected_bboxes = []
+        for bbox in bboxes:
+            bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size)
+            expected_bboxes.append(_compute_expected_bbox(bbox, output_size))
+
+        if len(expected_bboxes) > 1:
+            expected_bboxes = torch.stack(expected_bboxes)
+        else:
+            expected_bboxes = expected_bboxes[0]
+        torch.testing.assert_close(output_boxes, expected_bboxes)
diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py
index 59c53e70c26..cac2946b46e 100644
--- a/torchvision/prototype/transforms/functional/__init__.py
+++ b/torchvision/prototype/transforms/functional/__init__.py
@@ -45,6 +45,7 @@
     resize_image_tensor,
     resize_image_pil,
     resize_segmentation_mask,
+    center_crop_bounding_box,
     center_crop_image_tensor,
     center_crop_image_pil,
     resized_crop_bounding_box,
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index ae28bd84874..7d6e26451c9 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -619,6 +619,17 @@ def center_crop_image_pil(img: PIL.Image.Image, output_size: List[int]) -> PIL.I
     return crop_image_pil(img, crop_top, crop_left, crop_height, crop_width)
 
 
+def center_crop_bounding_box(
+    bounding_box: torch.Tensor,
+    format: features.BoundingBoxFormat,
+    output_size: List[int],
+    image_size: Tuple[int, int],
+) -> torch.Tensor:
+    crop_height, crop_width = _center_crop_parse_output_size(output_size)
+    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *image_size)
+    return crop_bounding_box(bounding_box, format, top=crop_top, left=crop_left)
+
+
 def resized_crop_image_tensor(
     img: torch.Tensor,
     top: int,
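
A minimal usage sketch of the new kernel follows (not part of the patch above); the box coordinates, image size, and output size are made-up values, and the imports assume the prototype namespace touched by this diff:

    from torchvision.prototype import features
    from torchvision.prototype.transforms import functional as F

    # A single box on a 32x32 image, given in XYXY format.
    bbox = features.BoundingBox(
        [10, 10, 20, 20], format=features.BoundingBoxFormat.XYXY, image_size=(32, 32)
    )

    # Center-crop the coordinate frame to 18x18. The crop anchor is
    # round((32 - 18) / 2) = 7 on both axes, so the box coordinates shift
    # to [3, 3, 13, 13] relative to the cropped image.
    out = F.center_crop_bounding_box(
        bbox, bbox.format, output_size=[18, 18], image_size=bbox.image_size
    )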