@@ -160,9 +160,9 @@ def draw_bounding_boxes(
160
160
the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
161
161
`0 <= ymin < ymax < H`.
162
162
labels (List[str]): List containing the labels of bounding boxes.
163
- colors (Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]] ): List containing the colors
164
- or a single color for all of the bounding boxes. The colors can be represented as `str` or
165
- `Tuple[int, int, int] `.
163
+ colors (color or list of colors, optional ): List containing the colors
164
+ of the boxes or single color for all boxes. The color can be represented as
165
+ PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)` `.
166
166
fill (bool): If `True` fills the bounding box with specified color.
167
167
width (int): Width of bounding box.
168
168
font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
@@ -231,7 +231,7 @@ def draw_segmentation_masks(
231
231
image : torch .Tensor ,
232
232
masks : torch .Tensor ,
233
233
alpha : float = 0.8 ,
234
- colors : Optional [List [Union [str , Tuple [int , int , int ]]]] = None ,
234
+ colors : Optional [Union [ List [Union [str , Tuple [int , int , int ]]], str , Tuple [ int , int , int ]]] = None ,
235
235
) -> torch .Tensor :
236
236
237
237
"""
@@ -243,10 +243,10 @@ def draw_segmentation_masks(
243
243
masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
244
244
alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
245
245
0 means full transparency, 1 means no transparency.
246
- colors (list or None ): List containing the colors of the masks. The colors can
247
- be represented as PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
248
- When ``masks`` has a single entry of shape (H, W), you can pass a single color instead of a list
249
- with one element. By default, random colors are generated for each mask.
246
+ colors (color or list of colors, optional ): List containing the colors
247
+ of the masks or single color for all masks. The color can be represented as
248
+ PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
249
+ By default, random colors are generated for each mask.
250
250
251
251
Returns:
252
252
img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top.
@@ -289,8 +289,7 @@ def draw_segmentation_masks(
289
289
for color in colors :
290
290
if isinstance (color , str ):
291
291
color = ImageColor .getrgb (color )
292
- color = torch .tensor (color , dtype = out_dtype )
293
- colors_ .append (color )
292
+ colors_ .append (torch .tensor (color , dtype = out_dtype ))
294
293
295
294
img_to_draw = image .detach ().clone ()
296
295
# TODO: There might be a way to vectorize this
@@ -301,6 +300,6 @@ def draw_segmentation_masks(
301
300
return out .to (out_dtype )
302
301
303
302
304
- def _generate_color_palette (num_masks ):
303
+ def _generate_color_palette (num_masks : int ):
305
304
palette = torch .tensor ([2 ** 25 - 1 , 2 ** 15 - 1 , 2 ** 21 - 1 ])
306
305
return [tuple ((i * palette ) % 255 ) for i in range (num_masks )]
0 commit comments