Skip to content

Commit b525ffa

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] Allow K=1 in draw_keypoints (#8439)
Summary: Co-authored-by: Nicolas Hug <[email protected]> Reviewed By: vmoens Differential Revision: D58283863 fbshipit-source-id: d461dacb65272f4d8b477008ff1c8d33d3bd1141
1 parent 27810e8 commit b525ffa

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

test/test_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,13 @@ def test_draw_keypoints_vanilla():
355355
assert_equal(img, img_cp)
356356

357357

358+
def test_draw_keypoins_K_equals_one():
359+
# Non-regression test for https://github.com/pytorch/vision/pull/8439
360+
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
361+
keypoints = torch.tensor([[[10, 10]]], dtype=torch.float)
362+
utils.draw_keypoints(img, keypoints)
363+
364+
358365
@pytest.mark.parametrize("colors", ["red", "#FF00FF", (1, 34, 122)])
359366
def test_draw_keypoints_colored(colors):
360367
# Keypoints is declared on top as global variable

torchvision/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -392,10 +392,10 @@ def draw_keypoints(
392392
# validate visibility
393393
if visibility is None: # set default
394394
visibility = torch.ones(keypoints.shape[:-1], dtype=torch.bool)
395-
# If the last dimension is 1, e.g., after calling split([2, 1], dim=-1) on the output of a keypoint-prediction
396-
# model, make sure visibility has shape (num_instances, K).
397-
# Iff K = 1, this has unwanted behavior, but K=1 does not really make sense in the first place.
398-
visibility = visibility.squeeze(-1)
395+
if visibility.ndim == 3:
396+
# If visibility was passed as pred.split([2, 1], dim=-1), it will be of shape (num_instances, K, 1).
397+
# We make sure it is of shape (num_instances, K). This isn't documented, we're just being nice.
398+
visibility = visibility.squeeze(-1)
399399
if visibility.ndim != 2:
400400
raise ValueError(f"visibility must be of shape (num_instances, K). Got ndim={visibility.ndim}")
401401
if visibility.shape != keypoints.shape[:-1]:

0 commit comments

Comments
 (0)