diff --git a/test/test_utils.py b/test/test_utils.py index ac394b51d63..e89bef4a6d9 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -355,6 +355,13 @@ def test_draw_keypoints_vanilla(): assert_equal(img, img_cp) +def test_draw_keypoins_K_equals_one(): + # Non-regression test for https://github.com/pytorch/vision/pull/8439 + img = torch.full((3, 100, 100), 0, dtype=torch.uint8) + keypoints = torch.tensor([[[10, 10]]], dtype=torch.float) + utils.draw_keypoints(img, keypoints) + + @pytest.mark.parametrize("colors", ["red", "#FF00FF", (1, 34, 122)]) def test_draw_keypoints_colored(colors): # Keypoints is declared on top as global variable diff --git a/torchvision/utils.py b/torchvision/utils.py index 94b3ec65c87..6b2d19ec3dd 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -392,10 +392,10 @@ def draw_keypoints( # validate visibility if visibility is None: # set default visibility = torch.ones(keypoints.shape[:-1], dtype=torch.bool) - # If the last dimension is 1, e.g., after calling split([2, 1], dim=-1) on the output of a keypoint-prediction - # model, make sure visibility has shape (num_instances, K). - # Iff K = 1, this has unwanted behavior, but K=1 does not really make sense in the first place. - visibility = visibility.squeeze(-1) + if visibility.ndim == 3: + # If visibility was passed as pred.split([2, 1], dim=-1), it will be of shape (num_instances, K, 1). + # We make sure it is of shape (num_instances, K). This isn't documented, we're just being nice. + visibility = visibility.squeeze(-1) if visibility.ndim != 2: raise ValueError(f"visibility must be of shape (num_instances, K). Got ndim={visibility.ndim}") if visibility.shape != keypoints.shape[:-1]: