diff --git a/test/test_utils.py b/test/test_utils.py index ebe35a8cd14..4e3fa401e89 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -124,7 +124,7 @@ def test_draw_boxes_vanilla(): img = torch.full((3, 100, 100), 0, dtype=torch.uint8) img_cp = img.clone() boxes_cp = boxes.clone() - result = utils.draw_bounding_boxes(img, boxes, fill=False, width=7) + result = utils.draw_bounding_boxes(img, boxes, fill=False, width=7, colors="white") path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_vanilla.png") if not os.path.exists(path): @@ -149,7 +149,11 @@ def test_draw_invalid_boxes(): img_tp = ((1, 1, 1), (1, 2, 3)) img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float) img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8) + img_correct = torch.zeros((3, 10, 10), dtype=torch.uint8) boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) + labels_wrong = ["one", "two"] + colors_wrong = ["pink", "blue"] + with pytest.raises(TypeError, match="Tensor expected"): utils.draw_bounding_boxes(img_tp, boxes) with pytest.raises(ValueError, match="Tensor uint8 expected"): @@ -158,6 +162,10 @@ def test_draw_invalid_boxes(): utils.draw_bounding_boxes(img_wrong2, boxes) with pytest.raises(ValueError, match="Only grayscale and RGB images are supported"): utils.draw_bounding_boxes(img_wrong2[0][:2], boxes) + with pytest.raises(ValueError, match="Number of boxes"): + utils.draw_bounding_boxes(img_correct, boxes, labels_wrong) + with pytest.raises(ValueError, match="Number of colors"): + utils.draw_bounding_boxes(img_correct, boxes, colors=colors_wrong) @pytest.mark.parametrize( diff --git a/torchvision/utils.py b/torchvision/utils.py index 855f132d645..34e36c553dd 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -176,6 +176,7 @@ def draw_bounding_boxes( colors (color or list of colors, optional): List containing the colors of the boxes or single color for all boxes. The color can be represented as PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. + By default, random colors are generated for boxes. fill (bool): If `True` fills the bounding box with specified color. width (int): Width of bounding box. font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may @@ -198,45 +199,50 @@ def draw_bounding_boxes( elif image.size(0) not in {1, 3}: raise ValueError("Only grayscale and RGB images are supported") + num_boxes = boxes.shape[0] + + if labels is None: + labels: Union[List[str], List[None]] = [None] * num_boxes # type: ignore[no-redef] + elif len(labels) != num_boxes: + raise ValueError( + f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box." + ) + + if colors is None: + colors = _generate_color_palette(num_boxes) + elif isinstance(colors, list): + if len(colors) < num_boxes: + raise ValueError(f"Number of colors ({len(colors)}) is less than number of boxes ({num_boxes}). ") + else: # colors specifies a single color for all boxes + colors = [colors] * num_boxes + + colors = [(ImageColor.getrgb(color) if isinstance(color, str) else color) for color in colors] + + # Handle Grayscale images if image.size(0) == 1: image = torch.tile(image, (3, 1, 1)) ndarr = image.permute(1, 2, 0).cpu().numpy() img_to_draw = Image.fromarray(ndarr) - img_boxes = boxes.to(torch.int64).tolist() if fill: draw = ImageDraw.Draw(img_to_draw, "RGBA") - else: draw = ImageDraw.Draw(img_to_draw) txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size) - for i, bbox in enumerate(img_boxes): - if colors is None: - color = None - elif isinstance(colors, list): - color = colors[i] - else: - color = colors - + for bbox, color, label in zip(img_boxes, colors, labels): # type: ignore[arg-type] if fill: - if color is None: - fill_color = (255, 255, 255, 100) - elif isinstance(color, str): - # This will automatically raise Error if rgb cannot be parsed. - fill_color = ImageColor.getrgb(color) + (100,) - elif isinstance(color, tuple): - fill_color = color + (100,) + fill_color = color + (100,) draw.rectangle(bbox, width=width, outline=color, fill=fill_color) else: draw.rectangle(bbox, width=width, outline=color) - if labels is not None: + if label is not None: margin = width + 1 - draw.text((bbox[0] + margin, bbox[1] + margin), labels[i], fill=color, font=txt_font) + draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font) return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) @@ -505,9 +511,9 @@ def _make_colorwheel() -> torch.Tensor: return colorwheel -def _generate_color_palette(num_masks: int): +def _generate_color_palette(num_objects: int): palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) - return [tuple((i * palette) % 255) for i in range(num_masks)] + return [tuple((i * palette) % 255) for i in range(num_objects)] def _log_api_usage_once(obj: Any) -> None: