Skip to content

Commit 2ba2f1d

Browse files
authored
[proto] Speed-up crop on bboxes and tests (#6881)
* [proto] Speed-up crop on bboxes and tests
* Fix linter
* Update _geometry.py
* Fixed device issue
* Revert changes in test/prototype_transforms_kernel_infos.py
* Fixed failing correctness tests
1 parent 1921613 commit 2ba2f1d

File tree

3 files changed

+36
-22
lines changed

3 files changed

+36
-22
lines changed

test/prototype_transforms_kernel_infos.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -862,6 +862,27 @@ def sample_inputs_crop_video():
862862
yield ArgsKwargs(video_loader, top=4, left=3, height=7, width=8)
863863

864864

865+
def reference_crop_bounding_box(bounding_box, *, format, top, left, height, width):
    """Reference implementation of ``crop_bounding_box``.

    Cropping a bounding box is a pure translation by ``(-left, -top)``, so it
    is expressed here as a 2x3 affine matrix and delegated to the shared
    affine reference helper.  Returns the transformed boxes together with the
    spatial size of the crop, ``(height, width)``.
    """
    translation = np.array(
        [[1, 0, -left], [0, 1, -top]],
        dtype="float32",
    )
    cropped = reference_affine_bounding_box_helper(
        bounding_box, format=format, affine_matrix=translation
    )
    return cropped, (height, width)
877+
878+
879+
def reference_inputs_crop_bounding_box():
    """Yield reference-test arguments for ``crop_bounding_box``.

    Pairs every bounding-box loader (with no extra dims and with a (4,) batch
    dim) with the first and last entries of ``_CROP_PARAMS`` to keep the
    reference run small.
    """
    crop_params = [_CROP_PARAMS[0], _CROP_PARAMS[-1]]
    loaders = make_bounding_box_loaders(extra_dims=((), (4,)))
    for loader, params in itertools.product(loaders, crop_params):
        yield ArgsKwargs(loader, format=loader.format, **params)
884+
885+
865886
KERNEL_INFOS.extend(
866887
[
867888
KernelInfo(
@@ -875,6 +896,8 @@ def sample_inputs_crop_video():
875896
KernelInfo(
876897
F.crop_bounding_box,
877898
sample_inputs_fn=sample_inputs_crop_bounding_box,
899+
reference_fn=reference_crop_bounding_box,
900+
reference_inputs_fn=reference_inputs_crop_bounding_box,
878901
),
879902
KernelInfo(
880903
F.crop_mask,

test/test_prototype_transforms_functional.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -900,7 +900,8 @@ def test_correctness_center_crop_bounding_box(device, output_size):
900900
def _compute_expected_bbox(bbox, output_size_):
901901
format_ = bbox.format
902902
spatial_size_ = bbox.spatial_size
903-
bbox = convert_format_bounding_box(bbox, format_, features.BoundingBoxFormat.XYWH)
903+
dtype = bbox.dtype
904+
bbox = convert_format_bounding_box(bbox.float(), format_, features.BoundingBoxFormat.XYWH)
904905

905906
if len(output_size_) == 1:
906907
output_size_.append(output_size_[-1])
@@ -913,14 +914,9 @@ def _compute_expected_bbox(bbox, output_size_):
913914
bbox[2].item(),
914915
bbox[3].item(),
915916
]
916-
out_bbox = features.BoundingBox(
917-
out_bbox,
918-
format=features.BoundingBoxFormat.XYWH,
919-
spatial_size=output_size_,
920-
dtype=bbox.dtype,
921-
device=bbox.device,
922-
)
923-
return convert_format_bounding_box(out_bbox, features.BoundingBoxFormat.XYWH, format_)
917+
out_bbox = torch.tensor(out_bbox)
918+
out_bbox = convert_format_bounding_box(out_bbox, features.BoundingBoxFormat.XYWH, format_)
919+
return out_bbox.to(dtype=dtype, device=bbox.device)
924920

925921
for bboxes in make_bounding_boxes(extra_dims=((4,),)):
926922
bboxes = bboxes.to(device)

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -802,22 +802,17 @@ def crop_bounding_box(
802802
height: int,
803803
width: int,
804804
) -> Tuple[torch.Tensor, Tuple[int, int]]:
805-
# TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
806-
# BoundingBoxFormat instead of converting back and forth
807-
bounding_box = convert_format_bounding_box(
808-
bounding_box.clone(), old_format=format, new_format=features.BoundingBoxFormat.XYXY, inplace=True
809-
)
805+
806+
bounding_box = bounding_box.clone()
810807

811808
# Crop or implicit pad if left and/or top have negative values:
812-
bounding_box[..., 0::2] -= left
813-
bounding_box[..., 1::2] -= top
809+
if format == features.BoundingBoxFormat.XYXY:
810+
sub = torch.tensor([left, top, left, top], device=bounding_box.device)
811+
else:
812+
sub = torch.tensor([left, top, 0, 0], device=bounding_box.device)
813+
bounding_box = bounding_box.sub_(sub)
814814

815-
return (
816-
convert_format_bounding_box(
817-
bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True
818-
),
819-
(height, width),
820-
)
815+
return bounding_box, (height, width)
821816

822817

823818
def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:

0 commit comments

Comments
 (0)