Skip to content

Commit 3f33eeb

Browse files
Support random colors by default for draw_bounding_boxes (#5127)
* Add random colors * Update error message, pretty the code * Update edge cases * Change implementation to tuples * Fix bugs * Add tests * Reuse palette * small rename fix * Update tests and code * Simplify code * ufmt * fixed colors -> random colors in docstring * Actually simplify further * Silence mypy. Twice. lol. Co-authored-by: Nicolas Hug <[email protected]>
1 parent 435eddf commit 3f33eeb

File tree

2 files changed

+36
-22
lines changed

2 files changed

+36
-22
lines changed

test/test_utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def test_draw_boxes_vanilla():
124124
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
125125
img_cp = img.clone()
126126
boxes_cp = boxes.clone()
127-
result = utils.draw_bounding_boxes(img, boxes, fill=False, width=7)
127+
result = utils.draw_bounding_boxes(img, boxes, fill=False, width=7, colors="white")
128128

129129
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_vanilla.png")
130130
if not os.path.exists(path):
@@ -149,7 +149,11 @@ def test_draw_invalid_boxes():
149149
img_tp = ((1, 1, 1), (1, 2, 3))
150150
img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
151151
img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
152+
img_correct = torch.zeros((3, 10, 10), dtype=torch.uint8)
152153
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
154+
labels_wrong = ["one", "two"]
155+
colors_wrong = ["pink", "blue"]
156+
153157
with pytest.raises(TypeError, match="Tensor expected"):
154158
utils.draw_bounding_boxes(img_tp, boxes)
155159
with pytest.raises(ValueError, match="Tensor uint8 expected"):
@@ -158,6 +162,10 @@ def test_draw_invalid_boxes():
158162
utils.draw_bounding_boxes(img_wrong2, boxes)
159163
with pytest.raises(ValueError, match="Only grayscale and RGB images are supported"):
160164
utils.draw_bounding_boxes(img_wrong2[0][:2], boxes)
165+
with pytest.raises(ValueError, match="Number of boxes"):
166+
utils.draw_bounding_boxes(img_correct, boxes, labels_wrong)
167+
with pytest.raises(ValueError, match="Number of colors"):
168+
utils.draw_bounding_boxes(img_correct, boxes, colors=colors_wrong)
161169

162170

163171
@pytest.mark.parametrize(

torchvision/utils.py

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ def draw_bounding_boxes(
176176
colors (color or list of colors, optional): List containing the colors
177177
of the boxes or single color for all boxes. The color can be represented as
178178
PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
179+
By default, random colors are generated for boxes.
179180
fill (bool): If `True` fills the bounding box with specified color.
180181
width (int): Width of bounding box.
181182
font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
@@ -198,45 +199,50 @@ def draw_bounding_boxes(
198199
elif image.size(0) not in {1, 3}:
199200
raise ValueError("Only grayscale and RGB images are supported")
200201

202+
num_boxes = boxes.shape[0]
203+
204+
if labels is None:
205+
labels: Union[List[str], List[None]] = [None] * num_boxes # type: ignore[no-redef]
206+
elif len(labels) != num_boxes:
207+
raise ValueError(
208+
f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box."
209+
)
210+
211+
if colors is None:
212+
colors = _generate_color_palette(num_boxes)
213+
elif isinstance(colors, list):
214+
if len(colors) < num_boxes:
215+
raise ValueError(f"Number of colors ({len(colors)}) is less than number of boxes ({num_boxes}). ")
216+
else: # colors specifies a single color for all boxes
217+
colors = [colors] * num_boxes
218+
219+
colors = [(ImageColor.getrgb(color) if isinstance(color, str) else color) for color in colors]
220+
221+
# Handle Grayscale images
201222
if image.size(0) == 1:
202223
image = torch.tile(image, (3, 1, 1))
203224

204225
ndarr = image.permute(1, 2, 0).cpu().numpy()
205226
img_to_draw = Image.fromarray(ndarr)
206-
207227
img_boxes = boxes.to(torch.int64).tolist()
208228

209229
if fill:
210230
draw = ImageDraw.Draw(img_to_draw, "RGBA")
211-
212231
else:
213232
draw = ImageDraw.Draw(img_to_draw)
214233

215234
txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size)
216235

217-
for i, bbox in enumerate(img_boxes):
218-
if colors is None:
219-
color = None
220-
elif isinstance(colors, list):
221-
color = colors[i]
222-
else:
223-
color = colors
224-
236+
for bbox, color, label in zip(img_boxes, colors, labels): # type: ignore[arg-type]
225237
if fill:
226-
if color is None:
227-
fill_color = (255, 255, 255, 100)
228-
elif isinstance(color, str):
229-
# This will automatically raise Error if rgb cannot be parsed.
230-
fill_color = ImageColor.getrgb(color) + (100,)
231-
elif isinstance(color, tuple):
232-
fill_color = color + (100,)
238+
fill_color = color + (100,)
233239
draw.rectangle(bbox, width=width, outline=color, fill=fill_color)
234240
else:
235241
draw.rectangle(bbox, width=width, outline=color)
236242

237-
if labels is not None:
243+
if label is not None:
238244
margin = width + 1
239-
draw.text((bbox[0] + margin, bbox[1] + margin), labels[i], fill=color, font=txt_font)
245+
draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font)
240246

241247
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
242248

@@ -505,9 +511,9 @@ def _make_colorwheel() -> torch.Tensor:
505511
return colorwheel
506512

507513

508-
def _generate_color_palette(num_masks: int):
514+
def _generate_color_palette(num_objects: int):
509515
palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
510-
return [tuple((i * palette) % 255) for i in range(num_masks)]
516+
return [tuple((i * palette) % 255) for i in range(num_objects)]
511517

512518

513519
def _log_api_usage_once(obj: Any) -> None:

0 commit comments

Comments
 (0)