Commit 1b000ff

NicolasHug authored and facebook-github-bot committed
[fbsync] Make RandomIoUCrop compatible with SanitizeBoundingBoxes (#7268)
Summary:
Co-authored-by: Philip Meier <[email protected]>
Reviewed By: vmoens
Differential Revision: D44416643
fbshipit-source-id: b2e1ec5f1a9fa1edc2bb5a459606bcf30c9dbeb0
1 parent 90e01c2 commit 1b000ff

4 files changed: +47 -42 lines

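Taken together, the four diffs below make RandomIoUCrop keep invalid boxes in place as all-zero ("degenerate") rows and teach SanitizeBoundingBoxes to drop those rows from boxes, labels, and masks alike. A minimal sketch of the usage this enables, modelled on the dict-style samples used in the tests below; the import paths are an assumption (module locations moved around during the transforms v2 prototype period), and the image/box values are illustrative:

import torch
from torchvision import datapoints                    # assumption: datapoints location at this commit
from torchvision.transforms import v2 as transforms   # assumption: v2 namespace at this commit

H, W = 224, 224
sample = {
    "image": datapoints.Image(torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8)),
    "boxes": datapoints.BoundingBox(
        torch.tensor([[10.0, 10.0, 60.0, 60.0], [150.0, 150.0, 220.0, 220.0]]),
        format=datapoints.BoundingBoxFormat.XYXY,
        spatial_size=(H, W),
    ),
    "labels": torch.tensor([1, 2]),
}

pipeline = transforms.Compose(
    [
        transforms.RandomIoUCrop(),          # zeroes out boxes whose centers fall outside the crop
        transforms.SanitizeBoundingBoxes(),  # drops the zeroed (degenerate) boxes and their labels
    ]
)

out = pipeline(sample)
# out["boxes"] and out["labels"] keep the same length after sanitization.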

test/test_prototype_transforms.py  (+26 -17)

@@ -1488,16 +1488,13 @@ def test__transform(self, mocker):

         fn.assert_has_calls(expected_calls)

-        expected_within_targets = sum(is_within_crop_area)
-
         # check number of bboxes vs number of labels:
         output_bboxes = output[1]
         assert isinstance(output_bboxes, datapoints.BoundingBox)
-        assert len(output_bboxes) == expected_within_targets
+        assert (output_bboxes[~is_within_crop_area] == 0).all()

         output_masks = output[2]
         assert isinstance(output_masks, datapoints.Mask)
-        assert len(output_masks) == expected_within_targets


 class TestScaleJitter:
@@ -2253,10 +2250,11 @@ def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):


 @pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
-@pytest.mark.parametrize("label_type", (torch.Tensor, list))
 @pytest.mark.parametrize("data_augmentation", ("hflip", "lsj", "multiscale", "ssd", "ssdlite"))
 @pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor))
-def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
+@pytest.mark.parametrize("sanitize", (True, False))
+def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
+    torch.manual_seed(0)
     if data_augmentation == "hflip":
         t = [
             transforms.RandomHorizontalFlip(p=1),
@@ -2290,20 +2288,20 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
         t = [
             transforms.RandomPhotometricDistort(p=1),
             transforms.RandomZoomOut(fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0})),
-            # TODO: put back IoUCrop once we remove its hard requirement for Labels
-            # transforms.RandomIoUCrop(),
+            transforms.RandomIoUCrop(),
             transforms.RandomHorizontalFlip(p=1),
             to_tensor(),
             transforms.ConvertImageDtype(torch.float),
         ]
     elif data_augmentation == "ssdlite":
         t = [
-            # TODO: put back IoUCrop once we remove its hard requirement for Labels
-            # transforms.RandomIoUCrop(),
+            transforms.RandomIoUCrop(),
             transforms.RandomHorizontalFlip(p=1),
             to_tensor(),
             transforms.ConvertImageDtype(torch.float),
         ]
+    if sanitize:
+        t += [transforms.SanitizeBoundingBoxes()]
     t = transforms.Compose(t)

     num_boxes = 5
@@ -2317,10 +2315,7 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
         assert is_simple_tensor(image)

     label = torch.randint(0, 10, size=(num_boxes,))
-    if label_type is list:
-        label = label.tolist()

-    # TODO: is the shape of the boxes OK? Should it be (1, num_boxes, 4)?? Same for masks
     boxes = torch.randint(0, min(H, W) // 2, size=(num_boxes, 4))
     boxes[:, 2:] += boxes[:, :2]
     boxes = boxes.clamp(min=0, max=min(H, W))
@@ -2343,8 +2338,19 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
     assert isinstance(out["image"], datapoints.Image)
     assert isinstance(out["label"], type(sample["label"]))

-    out["label"] = torch.tensor(out["label"])
-    assert out["boxes"].shape[0] == out["masks"].shape[0] == out["label"].shape[0] == num_boxes
+    num_boxes_expected = {
+        # ssd and ssdlite contain RandomIoUCrop which may "remove" some bbox. It
+        # doesn't remove them strictly speaking, it just marks some boxes as
+        # degenerate and those boxes will be later removed by
+        # SanitizeBoundingBoxes(), which we add to the pipelines if the sanitize
+        # param is True.
+        # Note that the values below are probably specific to the random seed
+        # set above (which is fine).
+        (True, "ssd"): 4,
+        (True, "ssdlite"): 4,
+    }.get((sanitize, data_augmentation), num_boxes)
+
+    assert out["boxes"].shape[0] == out["masks"].shape[0] == out["label"].shape[0] == num_boxes_expected


 @pytest.mark.parametrize("min_size", (1, 10))
@@ -2377,20 +2383,23 @@ def test_sanitize_bounding_boxes(min_size, labels_getter):
     valid_indices = [i for (i, is_valid) in enumerate(is_valid_mask) if is_valid]

     boxes = torch.tensor(boxes)
-    labels = torch.arange(boxes.shape[-2])
+    labels = torch.arange(boxes.shape[0])

     boxes = datapoints.BoundingBox(
         boxes,
         format=datapoints.BoundingBoxFormat.XYXY,
         spatial_size=(H, W),
     )

+    masks = datapoints.Mask(torch.randint(0, 2, size=(boxes.shape[0], H, W)))
+
     sample = {
         "image": torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8),
         "labels": labels,
         "boxes": boxes,
         "whatever": torch.rand(10),
         "None": None,
+        "masks": masks,
     }

     out = transforms.SanitizeBoundingBoxes(min_size=min_size, labels_getter=labels_getter)(sample)
@@ -2402,7 +2411,7 @@ def test_sanitize_bounding_boxes(min_size, labels_getter):
         assert out["labels"] is sample["labels"]
     else:
         assert isinstance(out["labels"], torch.Tensor)
-        assert out["boxes"].shape[:-1] == out["labels"].shape
+        assert out["boxes"].shape[0] == out["labels"].shape[0] == out["masks"].shape[0]
         # This works because we conveniently set labels to arange(num_boxes)
         assert out["labels"].tolist() == valid_indices

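The reworked assertions above no longer expect RandomIoUCrop to shrink the box tensor; they check that boxes flagged as outside the crop are zeroed, and that SanitizeBoundingBoxes (when enabled) is what actually shrinks boxes, masks, and labels together. A toy illustration of the first check, with hand-made tensors standing in for the transform output (purely illustrative, not part of the test suite):

import torch

# Stand-ins for the values the test inspects after RandomIoUCrop ran.
is_within_crop_area = torch.tensor([True, False, True])
output_bboxes = torch.tensor(
    [
        [5.0, 5.0, 40.0, 40.0],    # center inside the crop -> kept as-is
        [0.0, 0.0, 0.0, 0.0],      # center outside the crop -> marked degenerate (all zeros)
        [10.0, 12.0, 70.0, 65.0],  # center inside the crop -> kept as-is
    ]
)

# The crop no longer drops rows, so the lengths still match ...
assert len(output_bboxes) == len(is_within_crop_area)
# ... and every row flagged as outside the crop is exactly zero.
assert (output_bboxes[~is_within_crop_area] == 0).all()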

test/test_prototype_transforms_consistency.py  (+10 -7)

@@ -1090,13 +1090,16 @@ def make_datapoints(self, with_mask=True):
         "t_ref, t, data_kwargs",
         [
             (det_transforms.RandomHorizontalFlip(p=1.0), v2_transforms.RandomHorizontalFlip(p=1.0), {}),
-            # FIXME: make
-            # v2_transforms.Compose([
-            #     v2_transforms.RandomIoUCrop(),
-            #     v2_transforms.SanitizeBoundingBoxes()
-            # ])
-            # work
-            # (det_transforms.RandomIoUCrop(), v2_transforms.RandomIoUCrop(), {"with_mask": False}),
+            (
+                det_transforms.RandomIoUCrop(),
+                v2_transforms.Compose(
+                    [
+                        v2_transforms.RandomIoUCrop(),
+                        v2_transforms.SanitizeBoundingBoxes(labels_getter=lambda sample: sample[1]["labels"]),
+                    ]
+                ),
+                {"with_mask": False},
+            ),
             (det_transforms.RandomZoomOut(), v2_transforms.RandomZoomOut(), {"with_mask": False}),
             (det_transforms.ScaleJitter((1024, 1024)), v2_transforms.ScaleJitter((1024, 1024)), {}),
             (
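The reference detection transforms in this consistency test operate on (image, target) pairs rather than flat dicts, so SanitizeBoundingBoxes cannot locate the labels with its default heuristic; the lambda points it at target["labels"]. A small self-contained illustration of what that getter resolves to (tensor shapes are made up for the example):

import torch

# Mirrors the labels_getter used above: the sample is an (image, target) tuple,
# and the labels live inside the target dict.
labels_getter = lambda sample: sample[1]["labels"]

image = torch.rand(3, 32, 32)
target = {"boxes": torch.zeros(2, 4), "labels": torch.tensor([1, 3])}
sample = (image, target)

assert labels_getter(sample) is target["labels"]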

torchvision/transforms/v2/_geometry.py  (+3 -12)

@@ -721,8 +721,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
                 if left == right or top == bottom:
                     continue

-                # FIXME: I think we can stop here?
-
                 # check for any valid boxes with centers within the crop area
                 xyxy_bboxes = F.convert_format_bounding_box(
                     bboxes.as_subclass(torch.Tensor), bboxes.format, datapoints.BoundingBoxFormat.XYXY
@@ -745,23 +743,16 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
                 return dict(top=top, left=left, height=new_h, width=new_w, is_within_crop_area=is_within_crop_area)

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        # FIXME: refactor this to not remove anything

         if len(params) < 1:
             return inpt

-        is_within_crop_area = params["is_within_crop_area"]
-
         output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])

         if isinstance(output, datapoints.BoundingBox):
-            bboxes = output[is_within_crop_area]
-            bboxes = F.clamp_bounding_box(bboxes, output.format, output.spatial_size)
-            output = datapoints.BoundingBox.wrap_like(output, bboxes)
-        elif isinstance(output, datapoints.Mask):
-            # apply is_within_crop_area if mask is one-hot encoded
-            masks = output[is_within_crop_area]
-            output = datapoints.Mask.wrap_like(output, masks)
+            # We "mark" the invalid boxes as degenerate, and they can be
+            # removed by a later call to SanitizeBoundingBoxes()
+            output[~params["is_within_crop_area"]] = 0

         return output
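The new _transform no longer indexes boxes or masks out of the sample; it keeps every row and zeroes the ones whose centers fell outside the sampled crop, deferring the actual removal to SanitizeBoundingBoxes. A standalone sketch of that marking step on a plain tensor (hypothetical helper name, for illustration only):

import torch

def mark_degenerate(boxes: torch.Tensor, is_within_crop_area: torch.Tensor) -> torch.Tensor:
    """Zero out boxes whose centers fell outside the crop, keeping the row count."""
    boxes = boxes.clone()
    boxes[~is_within_crop_area] = 0
    return boxes

boxes = torch.tensor([[2.0, 2.0, 10.0, 10.0], [50.0, 50.0, 90.0, 90.0]])
is_within_crop_area = torch.tensor([True, False])

print(mark_degenerate(boxes, is_within_crop_area))
# tensor([[ 2.,  2., 10., 10.],
#         [ 0.,  0.,  0.,  0.]])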

torchvision/transforms/v2/_misc.py  (+8 -6)

@@ -265,14 +265,14 @@ def forward(self, *inputs: Any) -> Any:
            ),
        )
        ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
-        mask = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1)
+        valid = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1)
        # TODO: Do we really need to check for out of bounds here? All
        # transforms should be clamping anyway, so this should never happen?
        image_h, image_w = boxes.spatial_size
-        mask &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
-        mask &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)
+        valid &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
+        valid &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)

-        params = dict(mask=mask, labels=labels)
+        params = dict(valid=valid, labels=labels)
        flat_outputs = [
            # Even-though it may look like we're transforming all inputs, we don't:
            # _transform() will only care about BoundingBoxes and the labels
@@ -284,7 +284,9 @@ def forward(self, *inputs: Any) -> Any:

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:

-        if (inpt is not None and inpt is params["labels"]) or isinstance(inpt, datapoints.BoundingBox):
-            inpt = inpt[params["mask"]]
+        if (inpt is not None and inpt is params["labels"]) or isinstance(
+            inpt, (datapoints.BoundingBox, datapoints.Mask)
+        ):
+            inpt = inpt[params["valid"]]

        return inpt
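With the rename to valid and the extra datapoints.Mask case, the same boolean mask now filters boxes, labels, and segmentation masks, so the three stay index-aligned after sanitization. A minimal stand-in for that filtering step (hand-built tensors, illustrative only):

import torch

# One "valid" flag per box; the middle box is degenerate and gets dropped everywhere.
valid = torch.tensor([True, False, True])

boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [0.0, 0.0, 0.0, 0.0], [5.0, 5.0, 20.0, 20.0]])
labels = torch.tensor([1, 2, 3])
masks = torch.randint(0, 2, size=(3, 32, 32))

boxes, labels, masks = boxes[valid], labels[valid], masks[valid]
assert boxes.shape[0] == labels.shape[0] == masks.shape[0] == 2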
