@@ -1341,23 +1341,24 @@ def test__extract_image_targets_assertion(self, mocker):
1341
1341
mocker .MagicMock (spec = features .SegmentationMask ),
1342
1342
]
1343
1343
1344
- with pytest .raises (TypeError , match = "requires input sample to contain equal- sized list of Images" ):
1344
+ with pytest .raises (TypeError , match = "requires input sample to contain equal sized list of Images" ):
1345
1345
transform ._extract_image_targets (flat_sample )
1346
1346
1347
1347
@pytest .mark .parametrize ("image_type" , [features .Image , PIL .Image .Image , torch .Tensor ])
1348
- def test__extract_image_targets (self , image_type , mocker ):
1348
+ @pytest .mark .parametrize ("label_type" , [features .Label , features .OneHotLabel ])
1349
+ def test__extract_image_targets (self , image_type , label_type , mocker ):
1349
1350
transform = transforms .SimpleCopyPaste ()
1350
1351
1351
1352
flat_sample = [
1352
1353
# images, batch size = 2
1353
1354
self .create_fake_image (mocker , image_type ),
1354
1355
self .create_fake_image (mocker , image_type ),
1355
1356
# labels, bboxes, masks
1356
- mocker .MagicMock (spec = features . Label ),
1357
+ mocker .MagicMock (spec = label_type ),
1357
1358
mocker .MagicMock (spec = features .BoundingBox ),
1358
1359
mocker .MagicMock (spec = features .SegmentationMask ),
1359
1360
# labels, bboxes, masks
1360
- mocker .MagicMock (spec = features . Label ),
1361
+ mocker .MagicMock (spec = label_type ),
1361
1362
mocker .MagicMock (spec = features .BoundingBox ),
1362
1363
mocker .MagicMock (spec = features .SegmentationMask ),
1363
1364
]
@@ -1372,29 +1373,46 @@ def test__extract_image_targets(self, image_type, mocker):
1372
1373
assert images [0 ] == flat_sample [0 ]
1373
1374
assert images [1 ] == flat_sample [1 ]
1374
1375
1375
- def test__copy_paste (self ):
1376
+ for target in targets :
1377
+ for key , type_ in [
1378
+ ("boxes" , features .BoundingBox ),
1379
+ ("masks" , features .SegmentationMask ),
1380
+ ("labels" , label_type ),
1381
+ ]:
1382
+ assert key in target
1383
+ assert isinstance (target [key ], type_ )
1384
+ assert target [key ] in flat_sample
1385
+
1386
+ @pytest .mark .parametrize ("label_type" , [features .Label , features .OneHotLabel ])
1387
+ def test__copy_paste (self , label_type ):
1376
1388
image = 2 * torch .ones (3 , 32 , 32 )
1377
1389
masks = torch .zeros (2 , 32 , 32 )
1378
1390
masks [0 , 3 :9 , 2 :8 ] = 1
1379
1391
masks [1 , 20 :30 , 20 :30 ] = 1
1392
+ labels = torch .tensor ([1 , 2 ])
1393
+ if label_type == features .OneHotLabel :
1394
+ labels = torch .nn .functional .one_hot (labels , num_classes = 5 )
1380
1395
target = {
1381
1396
"boxes" : features .BoundingBox (
1382
1397
torch .tensor ([[2.0 , 3.0 , 8.0 , 9.0 ], [20.0 , 20.0 , 30.0 , 30.0 ]]), format = "XYXY" , image_size = (32 , 32 )
1383
1398
),
1384
1399
"masks" : features .SegmentationMask (masks ),
1385
- "labels" : features . Label ( torch . tensor ([ 1 , 2 ]) ),
1400
+ "labels" : label_type ( labels ),
1386
1401
}
1387
1402
1388
1403
paste_image = 10 * torch .ones (3 , 32 , 32 )
1389
1404
paste_masks = torch .zeros (2 , 32 , 32 )
1390
1405
paste_masks [0 , 13 :19 , 12 :18 ] = 1
1391
1406
paste_masks [1 , 15 :19 , 1 :8 ] = 1
1407
+ paste_labels = torch .tensor ([3 , 4 ])
1408
+ if label_type == features .OneHotLabel :
1409
+ paste_labels = torch .nn .functional .one_hot (paste_labels , num_classes = 5 )
1392
1410
paste_target = {
1393
1411
"boxes" : features .BoundingBox (
1394
1412
torch .tensor ([[12.0 , 13.0 , 19.0 , 18.0 ], [1.0 , 15.0 , 8.0 , 19.0 ]]), format = "XYXY" , image_size = (32 , 32 )
1395
1413
),
1396
1414
"masks" : features .SegmentationMask (paste_masks ),
1397
- "labels" : features . Label ( torch . tensor ([ 3 , 4 ]) ),
1415
+ "labels" : label_type ( paste_labels ),
1398
1416
}
1399
1417
1400
1418
transform = transforms .SimpleCopyPaste ()
@@ -1405,7 +1423,12 @@ def test__copy_paste(self):
1405
1423
assert output_target ["boxes" ].shape == (4 , 4 )
1406
1424
torch .testing .assert_close (output_target ["boxes" ][:2 , :], target ["boxes" ])
1407
1425
torch .testing .assert_close (output_target ["boxes" ][2 :, :], paste_target ["boxes" ])
1408
- torch .testing .assert_close (output_target ["labels" ], features .Label (torch .tensor ([1 , 2 , 3 , 4 ])))
1426
+
1427
+ expected_labels = torch .tensor ([1 , 2 , 3 , 4 ])
1428
+ if label_type == features .OneHotLabel :
1429
+ expected_labels = torch .nn .functional .one_hot (expected_labels , num_classes = 5 )
1430
+ torch .testing .assert_close (output_target ["labels" ], label_type (expected_labels ))
1431
+
1409
1432
assert output_target ["masks" ].shape == (4 , 32 , 32 )
1410
1433
torch .testing .assert_close (output_target ["masks" ][:2 , :], target ["masks" ])
1411
1434
torch .testing .assert_close (output_target ["masks" ][2 :, :], paste_target ["masks" ])
0 commit comments