@@ -1488,16 +1488,13 @@ def test__transform(self, mocker):
1488
1488
1489
1489
fn .assert_has_calls (expected_calls )
1490
1490
1491
- expected_within_targets = sum (is_within_crop_area )
1492
-
1493
1491
# check number of bboxes vs number of labels:
1494
1492
output_bboxes = output [1 ]
1495
1493
assert isinstance (output_bboxes , datapoints .BoundingBox )
1496
- assert len (output_bboxes ) == expected_within_targets
1494
+ assert (output_bboxes [ ~ is_within_crop_area ] == 0 ). all ()
1497
1495
1498
1496
output_masks = output [2 ]
1499
1497
assert isinstance (output_masks , datapoints .Mask )
1500
- assert len (output_masks ) == expected_within_targets
1501
1498
1502
1499
1503
1500
class TestScaleJitter :
@@ -2253,10 +2250,11 @@ def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
2253
2250
2254
2251
2255
2252
@pytest .mark .parametrize ("image_type" , (PIL .Image , torch .Tensor , datapoints .Image ))
2256
- @pytest .mark .parametrize ("label_type" , (torch .Tensor , list ))
2257
2253
@pytest .mark .parametrize ("data_augmentation" , ("hflip" , "lsj" , "multiscale" , "ssd" , "ssdlite" ))
2258
2254
@pytest .mark .parametrize ("to_tensor" , (transforms .ToTensor , transforms .ToImageTensor ))
2259
- def test_detection_preset (image_type , label_type , data_augmentation , to_tensor ):
2255
+ @pytest .mark .parametrize ("sanitize" , (True , False ))
2256
+ def test_detection_preset (image_type , data_augmentation , to_tensor , sanitize ):
2257
+ torch .manual_seed (0 )
2260
2258
if data_augmentation == "hflip" :
2261
2259
t = [
2262
2260
transforms .RandomHorizontalFlip (p = 1 ),
@@ -2290,20 +2288,20 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
2290
2288
t = [
2291
2289
transforms .RandomPhotometricDistort (p = 1 ),
2292
2290
transforms .RandomZoomOut (fill = defaultdict (lambda : (123.0 , 117.0 , 104.0 ), {datapoints .Mask : 0 })),
2293
- # TODO: put back IoUCrop once we remove its hard requirement for Labels
2294
- # transforms.RandomIoUCrop(),
2291
+ transforms .RandomIoUCrop (),
2295
2292
transforms .RandomHorizontalFlip (p = 1 ),
2296
2293
to_tensor (),
2297
2294
transforms .ConvertImageDtype (torch .float ),
2298
2295
]
2299
2296
elif data_augmentation == "ssdlite" :
2300
2297
t = [
2301
- # TODO: put back IoUCrop once we remove its hard requirement for Labels
2302
- # transforms.RandomIoUCrop(),
2298
+ transforms .RandomIoUCrop (),
2303
2299
transforms .RandomHorizontalFlip (p = 1 ),
2304
2300
to_tensor (),
2305
2301
transforms .ConvertImageDtype (torch .float ),
2306
2302
]
2303
+ if sanitize :
2304
+ t += [transforms .SanitizeBoundingBoxes ()]
2307
2305
t = transforms .Compose (t )
2308
2306
2309
2307
num_boxes = 5
@@ -2317,10 +2315,7 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
2317
2315
assert is_simple_tensor (image )
2318
2316
2319
2317
label = torch .randint (0 , 10 , size = (num_boxes ,))
2320
- if label_type is list :
2321
- label = label .tolist ()
2322
2318
2323
- # TODO: is the shape of the boxes OK? Should it be (1, num_boxes, 4)?? Same for masks
2324
2319
boxes = torch .randint (0 , min (H , W ) // 2 , size = (num_boxes , 4 ))
2325
2320
boxes [:, 2 :] += boxes [:, :2 ]
2326
2321
boxes = boxes .clamp (min = 0 , max = min (H , W ))
@@ -2343,8 +2338,19 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
2343
2338
assert isinstance (out ["image" ], datapoints .Image )
2344
2339
assert isinstance (out ["label" ], type (sample ["label" ]))
2345
2340
2346
- out ["label" ] = torch .tensor (out ["label" ])
2347
- assert out ["boxes" ].shape [0 ] == out ["masks" ].shape [0 ] == out ["label" ].shape [0 ] == num_boxes
2341
+ num_boxes_expected = {
2342
+ # ssd and ssdlite contain RandomIoUCrop which may "remove" some bboxes. It
2343
+ # doesn't remove them strictly speaking, it just marks some boxes as
2344
+ # degenerate and those boxes will later be removed by
2345
+ # SanitizeBoundingBoxes(), which we add to the pipelines if the sanitize
2346
+ # param is True.
2347
+ # Note that the values below are probably specific to the random seed
2348
+ # set above (which is fine).
2349
+ (True , "ssd" ): 4 ,
2350
+ (True , "ssdlite" ): 4 ,
2351
+ }.get ((sanitize , data_augmentation ), num_boxes )
2352
+
2353
+ assert out ["boxes" ].shape [0 ] == out ["masks" ].shape [0 ] == out ["label" ].shape [0 ] == num_boxes_expected
2348
2354
2349
2355
2350
2356
@pytest .mark .parametrize ("min_size" , (1 , 10 ))
@@ -2377,20 +2383,23 @@ def test_sanitize_bounding_boxes(min_size, labels_getter):
2377
2383
valid_indices = [i for (i , is_valid ) in enumerate (is_valid_mask ) if is_valid ]
2378
2384
2379
2385
boxes = torch .tensor (boxes )
2380
- labels = torch .arange (boxes .shape [- 2 ])
2386
+ labels = torch .arange (boxes .shape [0 ])
2381
2387
2382
2388
boxes = datapoints .BoundingBox (
2383
2389
boxes ,
2384
2390
format = datapoints .BoundingBoxFormat .XYXY ,
2385
2391
spatial_size = (H , W ),
2386
2392
)
2387
2393
2394
+ masks = datapoints .Mask (torch .randint (0 , 2 , size = (boxes .shape [0 ], H , W )))
2395
+
2388
2396
sample = {
2389
2397
"image" : torch .randint (0 , 256 , size = (1 , 3 , H , W ), dtype = torch .uint8 ),
2390
2398
"labels" : labels ,
2391
2399
"boxes" : boxes ,
2392
2400
"whatever" : torch .rand (10 ),
2393
2401
"None" : None ,
2402
+ "masks" : masks ,
2394
2403
}
2395
2404
2396
2405
out = transforms .SanitizeBoundingBoxes (min_size = min_size , labels_getter = labels_getter )(sample )
@@ -2402,7 +2411,7 @@ def test_sanitize_bounding_boxes(min_size, labels_getter):
2402
2411
assert out ["labels" ] is sample ["labels" ]
2403
2412
else :
2404
2413
assert isinstance (out ["labels" ], torch .Tensor )
2405
- assert out ["boxes" ].shape [: - 1 ] == out ["labels" ].shape
2414
+ assert out ["boxes" ].shape [0 ] == out ["labels" ].shape [ 0 ] == out [ "masks" ]. shape [ 0 ]
2406
2415
# This works because we conveniently set labels to arange(num_boxes)
2407
2416
assert out ["labels" ].tolist () == valid_indices
2408
2417
0 commit comments