Skip to content

Commit 2cadbda

Browse files
YosuaMichaelfacebook-github-bot
authored andcommitted
[fbsync] [proto] Updated tests for CopyPaste on OneHotLabel (#6485)
Summary: * [proto] Updated tests for CopyPaste on OneHotLabel * Fixing test error Reviewed By: NicolasHug Differential Revision: D39131009 fbshipit-source-id: 81585873349816cb7144b8c1082b96b8acb36cae
1 parent 9af05ac commit 2cadbda

File tree

2 files changed

+32
-9
lines changed

2 files changed

+32
-9
lines changed

test/test_prototype_transforms.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1341,23 +1341,24 @@ def test__extract_image_targets_assertion(self, mocker):
13411341
mocker.MagicMock(spec=features.SegmentationMask),
13421342
]
13431343

1344-
with pytest.raises(TypeError, match="requires input sample to contain equal-sized list of Images"):
1344+
with pytest.raises(TypeError, match="requires input sample to contain equal sized list of Images"):
13451345
transform._extract_image_targets(flat_sample)
13461346

13471347
@pytest.mark.parametrize("image_type", [features.Image, PIL.Image.Image, torch.Tensor])
1348-
def test__extract_image_targets(self, image_type, mocker):
1348+
@pytest.mark.parametrize("label_type", [features.Label, features.OneHotLabel])
1349+
def test__extract_image_targets(self, image_type, label_type, mocker):
13491350
transform = transforms.SimpleCopyPaste()
13501351

13511352
flat_sample = [
13521353
# images, batch size = 2
13531354
self.create_fake_image(mocker, image_type),
13541355
self.create_fake_image(mocker, image_type),
13551356
# labels, bboxes, masks
1356-
mocker.MagicMock(spec=features.Label),
1357+
mocker.MagicMock(spec=label_type),
13571358
mocker.MagicMock(spec=features.BoundingBox),
13581359
mocker.MagicMock(spec=features.SegmentationMask),
13591360
# labels, bboxes, masks
1360-
mocker.MagicMock(spec=features.Label),
1361+
mocker.MagicMock(spec=label_type),
13611362
mocker.MagicMock(spec=features.BoundingBox),
13621363
mocker.MagicMock(spec=features.SegmentationMask),
13631364
]
@@ -1372,29 +1373,46 @@ def test__extract_image_targets(self, image_type, mocker):
13721373
assert images[0] == flat_sample[0]
13731374
assert images[1] == flat_sample[1]
13741375

1375-
def test__copy_paste(self):
1376+
for target in targets:
1377+
for key, type_ in [
1378+
("boxes", features.BoundingBox),
1379+
("masks", features.SegmentationMask),
1380+
("labels", label_type),
1381+
]:
1382+
assert key in target
1383+
assert isinstance(target[key], type_)
1384+
assert target[key] in flat_sample
1385+
1386+
@pytest.mark.parametrize("label_type", [features.Label, features.OneHotLabel])
1387+
def test__copy_paste(self, label_type):
13761388
image = 2 * torch.ones(3, 32, 32)
13771389
masks = torch.zeros(2, 32, 32)
13781390
masks[0, 3:9, 2:8] = 1
13791391
masks[1, 20:30, 20:30] = 1
1392+
labels = torch.tensor([1, 2])
1393+
if label_type == features.OneHotLabel:
1394+
labels = torch.nn.functional.one_hot(labels, num_classes=5)
13801395
target = {
13811396
"boxes": features.BoundingBox(
13821397
torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", image_size=(32, 32)
13831398
),
13841399
"masks": features.SegmentationMask(masks),
1385-
"labels": features.Label(torch.tensor([1, 2])),
1400+
"labels": label_type(labels),
13861401
}
13871402

13881403
paste_image = 10 * torch.ones(3, 32, 32)
13891404
paste_masks = torch.zeros(2, 32, 32)
13901405
paste_masks[0, 13:19, 12:18] = 1
13911406
paste_masks[1, 15:19, 1:8] = 1
1407+
paste_labels = torch.tensor([3, 4])
1408+
if label_type == features.OneHotLabel:
1409+
paste_labels = torch.nn.functional.one_hot(paste_labels, num_classes=5)
13921410
paste_target = {
13931411
"boxes": features.BoundingBox(
13941412
torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", image_size=(32, 32)
13951413
),
13961414
"masks": features.SegmentationMask(paste_masks),
1397-
"labels": features.Label(torch.tensor([3, 4])),
1415+
"labels": label_type(paste_labels),
13981416
}
13991417

14001418
transform = transforms.SimpleCopyPaste()
@@ -1405,7 +1423,12 @@ def test__copy_paste(self):
14051423
assert output_target["boxes"].shape == (4, 4)
14061424
torch.testing.assert_close(output_target["boxes"][:2, :], target["boxes"])
14071425
torch.testing.assert_close(output_target["boxes"][2:, :], paste_target["boxes"])
1408-
torch.testing.assert_close(output_target["labels"], features.Label(torch.tensor([1, 2, 3, 4])))
1426+
1427+
expected_labels = torch.tensor([1, 2, 3, 4])
1428+
if label_type == features.OneHotLabel:
1429+
expected_labels = torch.nn.functional.one_hot(expected_labels, num_classes=5)
1430+
torch.testing.assert_close(output_target["labels"], label_type(expected_labels))
1431+
14091432
assert output_target["masks"].shape == (4, 32, 32)
14101433
torch.testing.assert_close(output_target["masks"][:2, :], target["masks"])
14111434
torch.testing.assert_close(output_target["masks"][2:, :], paste_target["masks"])

torchvision/prototype/transforms/_augment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ def _extract_image_targets(self, flat_sample: List[Any]) -> Tuple[List[Any], Lis
288288

289289
if not (len(images) == len(bboxes) == len(masks) == len(labels)):
290290
raise TypeError(
291-
f"{type(self).__name__}() requires input sample to contain equal-sized list of Images, "
291+
f"{type(self).__name__}() requires input sample to contain equal sized list of Images, "
292292
"BoundingBoxes, Segmentation Masks and Labels or OneHotLabels."
293293
)
294294

0 commit comments

Comments
 (0)