@@ -200,7 +200,10 @@ def draw_bounding_boxes(
200
200
raise ValueError ("Only grayscale and RGB images are supported" )
201
201
202
202
num_boxes = boxes .shape [0 ]
203
- if labels and len (labels ) != num_boxes :
203
+
204
+ if labels is None :
205
+ labels = [None ] * num_boxes
206
+ elif len (labels ) != num_boxes :
204
207
raise ValueError (
205
208
f"Number of boxes ({ num_boxes } ) and labels ({ len (labels )} ) mismatch. Please specify labels for each box."
206
209
)
@@ -230,16 +233,16 @@ def draw_bounding_boxes(
230
233
231
234
txt_font = ImageFont .load_default () if font is None else ImageFont .truetype (font = font , size = font_size )
232
235
233
- for i , ( bbox , color ) in enumerate ( zip (img_boxes , colors ) ):
236
+ for bbox , color , label in zip (img_boxes , colors , labels ):
234
237
if fill :
235
238
fill_color = color + (100 ,)
236
239
draw .rectangle (bbox , width = width , outline = color , fill = fill_color )
237
240
else :
238
241
draw .rectangle (bbox , width = width , outline = color )
239
242
240
- if labels is not None :
243
+ if label is not None :
241
244
margin = width + 1
242
- draw .text ((bbox [0 ] + margin , bbox [1 ] + margin ), labels [ i ] , fill = color , font = txt_font )
245
+ draw .text ((bbox [0 ] + margin , bbox [1 ] + margin ), label , fill = color , font = txt_font )
243
246
244
247
return torch .from_numpy (np .array (img_to_draw )).permute (2 , 0 , 1 ).to (dtype = torch .uint8 )
245
248
0 commit comments