Skip to content

Commit fcba00a

Browse files
committed
Actually simplify further
1 parent 4c95a5c commit fcba00a

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

torchvision/utils.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,10 @@ def draw_bounding_boxes(
200200
raise ValueError("Only grayscale and RGB images are supported")
201201

202202
num_boxes = boxes.shape[0]
203-
if labels and len(labels) != num_boxes:
203+
204+
if labels is None:
205+
labels = [None] * num_boxes
206+
elif len(labels) != num_boxes:
204207
raise ValueError(
205208
f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box."
206209
)
@@ -230,16 +233,16 @@ def draw_bounding_boxes(
230233

231234
txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size)
232235

233-
for i, (bbox, color) in enumerate(zip(img_boxes, colors)):
236+
for bbox, color, label in zip(img_boxes, colors, labels):
234237
if fill:
235238
fill_color = color + (100,)
236239
draw.rectangle(bbox, width=width, outline=color, fill=fill_color)
237240
else:
238241
draw.rectangle(bbox, width=width, outline=color)
239242

240-
if labels is not None:
243+
if label is not None:
241244
margin = width + 1
242-
draw.text((bbox[0] + margin, bbox[1] + margin), labels[i], fill=color, font=txt_font)
245+
draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font)
243246

244247
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
245248

0 commit comments

Comments
 (0)