@@ -176,6 +176,7 @@ def draw_bounding_boxes(
176
176
colors (color or list of colors, optional): List containing the colors
177
177
of the boxes or single color for all boxes. The color can be represented as
178
178
PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
179
+ By default, random colors are generated for boxes.
179
180
fill (bool): If `True` fills the bounding box with specified color.
180
181
width (int): Width of bounding box.
181
182
font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
@@ -198,45 +199,50 @@ def draw_bounding_boxes(
198
199
elif image .size (0 ) not in {1 , 3 }:
199
200
raise ValueError ("Only grayscale and RGB images are supported" )
200
201
202
+ num_boxes = boxes .shape [0 ]
203
+
204
+ if labels is None :
205
+ labels : Union [List [str ], List [None ]] = [None ] * num_boxes # type: ignore[no-redef]
206
+ elif len (labels ) != num_boxes :
207
+ raise ValueError (
208
+ f"Number of boxes ({ num_boxes } ) and labels ({ len (labels )} ) mismatch. Please specify labels for each box."
209
+ )
210
+
211
+ if colors is None :
212
+ colors = _generate_color_palette (num_boxes )
213
+ elif isinstance (colors , list ):
214
+ if len (colors ) < num_boxes :
215
+ raise ValueError (f"Number of colors ({ len (colors )} ) is less than number of boxes ({ num_boxes } ). " )
216
+ else : # colors specifies a single color for all boxes
217
+ colors = [colors ] * num_boxes
218
+
219
+ colors = [(ImageColor .getrgb (color ) if isinstance (color , str ) else color ) for color in colors ]
220
+
221
+ # Handle Grayscale images
201
222
if image .size (0 ) == 1 :
202
223
image = torch .tile (image , (3 , 1 , 1 ))
203
224
204
225
ndarr = image .permute (1 , 2 , 0 ).cpu ().numpy ()
205
226
img_to_draw = Image .fromarray (ndarr )
206
-
207
227
img_boxes = boxes .to (torch .int64 ).tolist ()
208
228
209
229
if fill :
210
230
draw = ImageDraw .Draw (img_to_draw , "RGBA" )
211
-
212
231
else :
213
232
draw = ImageDraw .Draw (img_to_draw )
214
233
215
234
txt_font = ImageFont .load_default () if font is None else ImageFont .truetype (font = font , size = font_size )
216
235
217
- for i , bbox in enumerate (img_boxes ):
218
- if colors is None :
219
- color = None
220
- elif isinstance (colors , list ):
221
- color = colors [i ]
222
- else :
223
- color = colors
224
-
236
+ for bbox , color , label in zip (img_boxes , colors , labels ): # type: ignore[arg-type]
225
237
if fill :
226
- if color is None :
227
- fill_color = (255 , 255 , 255 , 100 )
228
- elif isinstance (color , str ):
229
- # This will automatically raise Error if rgb cannot be parsed.
230
- fill_color = ImageColor .getrgb (color ) + (100 ,)
231
- elif isinstance (color , tuple ):
232
- fill_color = color + (100 ,)
238
+ fill_color = color + (100 ,)
233
239
draw .rectangle (bbox , width = width , outline = color , fill = fill_color )
234
240
else :
235
241
draw .rectangle (bbox , width = width , outline = color )
236
242
237
- if labels is not None :
243
+ if label is not None :
238
244
margin = width + 1
239
- draw .text ((bbox [0 ] + margin , bbox [1 ] + margin ), labels [ i ] , fill = color , font = txt_font )
245
+ draw .text ((bbox [0 ] + margin , bbox [1 ] + margin ), label , fill = color , font = txt_font )
240
246
241
247
return torch .from_numpy (np .array (img_to_draw )).permute (2 , 0 , 1 ).to (dtype = torch .uint8 )
242
248
@@ -505,9 +511,9 @@ def _make_colorwheel() -> torch.Tensor:
505
511
return colorwheel
506
512
507
513
508
- def _generate_color_palette (num_masks : int ):
514
+ def _generate_color_palette (num_objects : int ):
509
515
palette = torch .tensor ([2 ** 25 - 1 , 2 ** 15 - 1 , 2 ** 21 - 1 ])
510
- return [tuple ((i * palette ) % 255 ) for i in range (num_masks )]
516
+ return [tuple ((i * palette ) % 255 ) for i in range (num_objects )]
511
517
512
518
513
519
def _log_api_usage_once (obj : Any ) -> None :
0 commit comments