Skip to content

Commit a75dc89

Browse files
Fix annotation of draw_segmentation_masks (#4527)
* Add str param * Update test to include str * Fix mypy * Remove a small bracket * Test more robustly * Update docstring and test: * Apply suggestions from code review Co-authored-by: Nicolas Hug <[email protected]> * Update torchvision/utils.py Small docstring fix * Update torchvision/utils.py * remove unnecessary renaming Co-authored-by: Nicolas Hug <[email protected]> Co-authored-by: Nicolas Hug <[email protected]>
1 parent 4d711fd commit a75dc89

File tree

2 files changed

+15
-11
lines changed

2 files changed

+15
-11
lines changed

test/test_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,9 @@ def test_draw_invalid_boxes():
162162
"colors",
163163
[
164164
None,
165+
"blue",
166+
"#FF00FF",
167+
(1, 34, 122),
165168
["red", "blue"],
166169
["#FF00FF", (1, 34, 122)],
167170
],
@@ -191,6 +194,8 @@ def test_draw_segmentation_masks(colors, alpha):
191194

192195
if colors is None:
193196
colors = utils._generate_color_palette(num_masks)
197+
elif isinstance(colors, str) or isinstance(colors, tuple):
198+
colors = [colors]
194199

195200
# Make sure each mask draws with its own color
196201
for mask, color in zip(masks, colors):

torchvision/utils.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,9 @@ def draw_bounding_boxes(
160160
the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
161161
`0 <= ymin < ymax < H`.
162162
labels (List[str]): List containing the labels of bounding boxes.
163-
colors (Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]): List containing the colors
164-
or a single color for all of the bounding boxes. The colors can be represented as `str` or
165-
`Tuple[int, int, int]`.
163+
colors (color or list of colors, optional): List containing the colors
164+
of the boxes or single color for all boxes. The color can be represented as
165+
PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
166166
fill (bool): If `True` fills the bounding box with specified color.
167167
width (int): Width of bounding box.
168168
font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
@@ -231,7 +231,7 @@ def draw_segmentation_masks(
231231
image: torch.Tensor,
232232
masks: torch.Tensor,
233233
alpha: float = 0.8,
234-
colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None,
234+
colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
235235
) -> torch.Tensor:
236236

237237
"""
@@ -243,10 +243,10 @@ def draw_segmentation_masks(
243243
masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
244244
alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
245245
0 means full transparency, 1 means no transparency.
246-
colors (list or None): List containing the colors of the masks. The colors can
247-
be represented as PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
248-
When ``masks`` has a single entry of shape (H, W), you can pass a single color instead of a list
249-
with one element. By default, random colors are generated for each mask.
246+
colors (color or list of colors, optional): List containing the colors
247+
of the masks or single color for all masks. The color can be represented as
248+
PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
249+
By default, random colors are generated for each mask.
250250
251251
Returns:
252252
img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top.
@@ -289,8 +289,7 @@ def draw_segmentation_masks(
289289
for color in colors:
290290
if isinstance(color, str):
291291
color = ImageColor.getrgb(color)
292-
color = torch.tensor(color, dtype=out_dtype)
293-
colors_.append(color)
292+
colors_.append(torch.tensor(color, dtype=out_dtype))
294293

295294
img_to_draw = image.detach().clone()
296295
# TODO: There might be a way to vectorize this
@@ -301,6 +300,6 @@ def draw_segmentation_masks(
301300
return out.to(out_dtype)
302301

303302

304-
def _generate_color_palette(num_masks):
303+
def _generate_color_palette(num_masks: int):
305304
palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
306305
return [tuple((i * palette) % 255) for i in range(num_masks)]

0 commit comments

Comments
 (0)