Skip to content

Commit f799a53

Browse files
Add label background (#9018)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent ac7ad5f commit f799a53

File tree

3 files changed

+29
-4
lines changed

3 files changed

+29
-4
lines changed
Loading

test/test_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,23 @@ def test_draw_boxes_with_coloured_labels():
131131
assert_equal(result, expected)
132132

133133

134+
@pytest.mark.skipif(PILLOW_VERSION < (10, 1), reason="The reference image is only valid for PIL >= 10.1")
135+
def test_draw_boxes_with_coloured_label_backgrounds():
136+
img = torch.full((3, 100, 100), 255, dtype=torch.uint8)
137+
labels = ["a", "b", "c", "d"]
138+
colors = ["green", "#FF00FF", (0, 255, 0), "red"]
139+
label_colors = ["green", "red", (0, 255, 0), "#FF00FF"]
140+
result = utils.draw_bounding_boxes(
141+
img, boxes, labels=labels, colors=colors, fill=True, label_colors=label_colors, fill_labels=True
142+
)
143+
144+
path = os.path.join(
145+
os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_different_label_fill_colors.png"
146+
)
147+
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
148+
assert_equal(result, expected)
149+
150+
134151
@pytest.mark.parametrize("fill", [True, False])
135152
def test_draw_boxes_dtypes(fill):
136153
img_uint8 = torch.full((3, 100, 100), 255, dtype=torch.uint8)

torchvision/utils.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ def draw_bounding_boxes(
162162
font: Optional[str] = None,
163163
font_size: Optional[int] = None,
164164
label_colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
165+
fill_labels: bool = False,
165166
) -> torch.Tensor:
166167

167168
"""
@@ -186,7 +187,8 @@ def draw_bounding_boxes(
186187
`/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
187188
font_size (int): The requested font size in points.
188189
label_colors (color or list of colors, optional): Colors for the label text. See the description of the
189-
`colors` argument for details. Defaults to the same colors used for the boxes.
190+
`colors` argument for details. Defaults to the same colors used for the boxes, or to black if ``fill_labels`` is True.
191+
fill_labels (bool): If `True` fills the label background with specified box color (from the ``colors`` parameter). Default: False.
190192
191193
Returns:
192194
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted.
@@ -223,8 +225,8 @@ def draw_bounding_boxes(
223225
)
224226

225227
colors = _parse_colors(colors, num_objects=num_boxes)
226-
if label_colors:
227-
label_colors = _parse_colors(label_colors, num_objects=num_boxes) # type: ignore[assignment]
228+
if label_colors or fill_labels:
229+
label_colors = _parse_colors(label_colors if label_colors else "black", num_objects=num_boxes) # type: ignore[assignment]
228230
else:
229231
label_colors = colors.copy() # type: ignore[assignment]
230232

@@ -259,7 +261,13 @@ def draw_bounding_boxes(
259261
draw.rectangle(bbox, width=width, outline=color)
260262

261263
if label is not None:
262-
margin = width + 1
264+
box_margin = 1
265+
margin = width + box_margin
266+
if fill_labels:
267+
left, top, right, bottom = draw.textbbox((bbox[0] + margin, bbox[1] + margin), label, font=txt_font)
268+
draw.rectangle(
269+
(left - box_margin, top - box_margin, right + box_margin, bottom + box_margin), fill=color
270+
)
263271
draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=label_color, font=txt_font) # type: ignore[arg-type]
264272

265273
out = F.pil_to_tensor(img_to_draw)

0 commit comments

Comments
 (0)