Skip to content

Commit 5486b76

Browse files
oxabzLEGRAND MatthieuNicolasHug
authored
Throw warning for empty masks or box tensors on draw_segmentation_masks and draw_bounding_boxes (#5857)
* Fixing the IndexError in draw_segmentation_masks * fixing the bug on draw_bounding_boxes * Changing fstring to normal string * Removing unecessary conversion * Adding test for the change * Adding a test for draw seqmentation mask * Fixing small mistake * Fixing an error in the tests * removing useless imports * ufmt Co-authored-by: LEGRAND Matthieu <[email protected]> Co-authored-by: Nicolas Hug <[email protected]>
1 parent b969cca commit 5486b76

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

test/test_utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,15 @@ def test_draw_boxes_warning():
176176
utils.draw_bounding_boxes(img, boxes, font_size=11)
177177

178178

179+
def test_draw_no_boxes():
180+
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
181+
boxes = torch.full((0, 4), 0, dtype=torch.float)
182+
with pytest.warns(UserWarning, match=re.escape("boxes doesn't contain any box. No box was drawn")):
183+
res = utils.draw_bounding_boxes(img, boxes)
184+
# Check that the function didnt change the image
185+
assert res.eq(img).all()
186+
187+
179188
@pytest.mark.parametrize(
180189
"colors",
181190
[
@@ -266,6 +275,15 @@ def test_draw_segmentation_masks_errors():
266275
utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors)
267276

268277

278+
def test_draw_no_segmention_mask():
279+
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
280+
masks = torch.full((0, 100, 100), 0, dtype=torch.bool)
281+
with pytest.warns(UserWarning, match=re.escape("masks doesn't contain any mask. No mask was drawn")):
282+
res = utils.draw_segmentation_masks(img, masks)
283+
# Check that the function didnt change the image
284+
assert res.eq(img).all()
285+
286+
269287
def test_draw_keypoints_vanilla():
270288
# Keypoints is declared on top as global variable
271289
keypoints_cp = keypoints.clone()

torchvision/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,10 @@ def draw_bounding_boxes(
211211

212212
num_boxes = boxes.shape[0]
213213

214+
if num_boxes == 0:
215+
warnings.warn("boxes doesn't contain any box. No box was drawn")
216+
return image
217+
214218
if labels is None:
215219
labels: Union[List[str], List[None]] = [None] * num_boxes # type: ignore[no-redef]
216220
elif len(labels) != num_boxes:
@@ -311,6 +315,10 @@ def draw_segmentation_masks(
311315
if colors is not None and num_masks > len(colors):
312316
raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})")
313317

318+
if num_masks == 0:
319+
warnings.warn("masks doesn't contain any mask. No mask was drawn")
320+
return image
321+
314322
if colors is None:
315323
colors = _generate_color_palette(num_masks)
316324

0 commit comments

Comments
 (0)