From 49d81db480f0fc33a9ff6f5decd2d134426bd7ae Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Sat, 25 Dec 2021 22:33:52 +0530 Subject: [PATCH 01/14] Add random colors --- test/test_utils.py | 2 +- torchvision/utils.py | 35 ++++++++++++++++++++++++++++++++--- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 30f144d8206..2c1ac8e2c3e 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -124,7 +124,7 @@ def test_draw_boxes_vanilla(): img = torch.full((3, 100, 100), 0, dtype=torch.uint8) img_cp = img.clone() boxes_cp = boxes.clone() - result = utils.draw_bounding_boxes(img, boxes, fill=False, width=7) + result = utils.draw_bounding_boxes(img, boxes, fill=False, width=7, colors="white") path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_vanilla.png") if not os.path.exists(path): diff --git a/torchvision/utils.py b/torchvision/utils.py index cbbccae0b95..e1b90d506c9 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -169,6 +169,8 @@ def draw_bounding_boxes( colors (color or list of colors, optional): List containing the colors of the boxes or single color for all boxes. The color can be represented as PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. + By default, random colors are generated for boxes. + If labels are provided, boxes with same labels have same color. fill (bool): If `True` fills the bounding box with specified color. width (int): Width of bounding box. font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may @@ -191,6 +193,10 @@ def draw_bounding_boxes( elif image.size(0) not in {1, 3}: raise ValueError("Only grayscale and RGB images are supported") + if labels: + if len(labels) != boxes.size(0): + raise ValueError("Specify labels for each box") + if image.size(0) == 1: image = torch.tile(image, (3, 1, 1)) @@ -207,10 +213,13 @@ def draw_bounding_boxes( txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size) + if colors is None: + colors = _generate_random_color_palette(len(img_boxes)) + label_color_map = dict(zip(labels, colors)) + colors = [label_color_map[label] for label in labels] + for i, bbox in enumerate(img_boxes): - if colors is None: - color = None - elif isinstance(colors, list): + if isinstance(colors, list): color = colors[i] else: color = colors @@ -387,6 +396,26 @@ def _generate_color_palette(num_masks: int): return [tuple((i * palette) % 255) for i in range(num_masks)] +def _generate_random_color() -> List[int]: + """ + Generates a random RGB Color. + Returns: + color(int, int, int): A List containing Random RGB Color. + """ + return torch.randperm(255)[:3].tolist() + + +def _generate_random_color_palette(num_colors) -> List[List[int]]: + """ + Args: + num_colors (int): Integer denoting number of random RGB colors to generate + + Returns: + colors(List[color]): A list each containing random RGB color. Each color is List containing RGB values. + """ + return [_generate_random_color() for i in range(num_colors)] + + def _log_api_usage_once(obj: Any) -> None: if not obj.__module__.startswith("torchvision"): return From 67cc6c082b55cb73f0b28ebdc6ffb9c2358c3a70 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Sat, 25 Dec 2021 23:06:14 +0530 Subject: [PATCH 02/14] Update error message, pretty the code --- torchvision/utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index e1b90d506c9..8593f2c9686 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -192,22 +192,20 @@ def draw_bounding_boxes( raise ValueError("Pass individual images, not batches") elif image.size(0) not in {1, 3}: raise ValueError("Only grayscale and RGB images are supported") - if labels: if len(labels) != boxes.size(0): - raise ValueError("Specify labels for each box") + raise ValueError("Length of boxes and labels mismatch. Please specify labels for each box.") + # Handle Grayscale images if image.size(0) == 1: image = torch.tile(image, (3, 1, 1)) ndarr = image.permute(1, 2, 0).cpu().numpy() img_to_draw = Image.fromarray(ndarr) - img_boxes = boxes.to(torch.int64).tolist() if fill: draw = ImageDraw.Draw(img_to_draw, "RGBA") - else: draw = ImageDraw.Draw(img_to_draw) From 3aabd21ccd5f85c68c7a448993651462ada485a5 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Sat, 25 Dec 2021 23:07:10 +0530 Subject: [PATCH 03/14] Update edge cases --- test/test_utils.py | 4 ++++ torchvision/utils.py | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 2c1ac8e2c3e..bb6dec87a62 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -149,7 +149,9 @@ def test_draw_invalid_boxes(): img_tp = ((1, 1, 1), (1, 2, 3)) img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float) img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8) + img_correct = torch.zeros((3, 10, 10), dtype=torch.uint8) boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) + labels_wrong = ["one", "two"] with pytest.raises(TypeError, match="Tensor expected"): utils.draw_bounding_boxes(img_tp, boxes) with pytest.raises(ValueError, match="Tensor uint8 expected"): @@ -158,6 +160,8 @@ def test_draw_invalid_boxes(): utils.draw_bounding_boxes(img_wrong2, boxes) with pytest.raises(ValueError, match="Only grayscale and RGB images are supported"): utils.draw_bounding_boxes(img_wrong2[0][:2], boxes) + with pytest.raises(ValueError, match="Length of boxes and labels mismatch."): + utils.draw_bounding_boxes(img_correct, boxes, labels_wrong) @pytest.mark.parametrize( diff --git a/torchvision/utils.py b/torchvision/utils.py index 8593f2c9686..b0650e5f826 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -400,10 +400,10 @@ def _generate_random_color() -> List[int]: Returns: color(int, int, int): A List containing Random RGB Color. """ - return torch.randperm(255)[:3].tolist() + return torch.randperm(256)[:3].tolist() -def _generate_random_color_palette(num_colors) -> List[List[int]]: +def _generate_random_color_palette(num_colors: int) -> List[List[int]]: """ Args: num_colors (int): Integer denoting number of random RGB colors to generate From 0566ee3c337f8b21114aadd9dee0e2bc86a31bbb Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Sat, 25 Dec 2021 23:36:18 +0530 Subject: [PATCH 04/14] Change implementation to tuples --- torchvision/utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index b0650e5f826..c294d22413f 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -394,22 +394,22 @@ def _generate_color_palette(num_masks: int): return [tuple((i * palette) % 255) for i in range(num_masks)] -def _generate_random_color() -> List[int]: +def _generate_random_color() -> Tuple[int, int, int]: """ Generates a random RGB Color. Returns: - color(int, int, int): A List containing Random RGB Color. + color Tuple(int, int, int): A Tuple containing Random RGB Color. """ - return torch.randperm(256)[:3].tolist() + return tuple(torch.randperm(256)[:3].tolist()) -def _generate_random_color_palette(num_colors: int) -> List[List[int]]: +def _generate_random_color_palette(num_colors) -> List[Tuple[int, int, int]]: """ Args: num_colors (int): Integer denoting number of random RGB colors to generate Returns: - colors(List[color]): A list each containing random RGB color. Each color is List containing RGB values. + colors(List[color]): A List each containing random RGB color. Each color is Tuple containing RGB values. """ return [_generate_random_color() for i in range(num_colors)] From 900ff7c70fabb4c6c8f9f1259c0fbad03168e80f Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Mon, 27 Dec 2021 00:43:58 +0530 Subject: [PATCH 05/14] Fix bugs --- test/test_utils.py | 6 +++++- torchvision/utils.py | 26 ++++++++++++++++++-------- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index bb6dec87a62..c671d2ee8bc 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -152,6 +152,8 @@ def test_draw_invalid_boxes(): img_correct = torch.zeros((3, 10, 10), dtype=torch.uint8) boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) labels_wrong = ["one", "two"] + colors_wrong = ["pink", "blue"] + with pytest.raises(TypeError, match="Tensor expected"): utils.draw_bounding_boxes(img_tp, boxes) with pytest.raises(ValueError, match="Tensor uint8 expected"): @@ -160,8 +162,10 @@ def test_draw_invalid_boxes(): utils.draw_bounding_boxes(img_wrong2, boxes) with pytest.raises(ValueError, match="Only grayscale and RGB images are supported"): utils.draw_bounding_boxes(img_wrong2[0][:2], boxes) - with pytest.raises(ValueError, match="Length of boxes and labels mismatch."): + with pytest.raises(ValueError, match="Number of boxes and labels mismatch."): utils.draw_bounding_boxes(img_correct, boxes, labels_wrong) + with pytest.raises(ValueError, match="Number of colors should be greater or equal to the number of boxes."): + utils.draw_bounding_boxes(img_correct, boxes, colors=colors_wrong) @pytest.mark.parametrize( diff --git a/torchvision/utils.py b/torchvision/utils.py index c294d22413f..526a4b5c45a 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -192,9 +192,16 @@ def draw_bounding_boxes( raise ValueError("Pass individual images, not batches") elif image.size(0) not in {1, 3}: raise ValueError("Only grayscale and RGB images are supported") + if labels: if len(labels) != boxes.size(0): - raise ValueError("Length of boxes and labels mismatch. Please specify labels for each box.") + raise ValueError("Number of boxes and labels mismatch. Please specify labels for each box.") + + # List of colors should have minimum colors. + if colors: + if isinstance(colors, list): + if len(colors) < boxes.size(0): + raise ValueError("Number of colors should be greater or equal to the number of boxes.") # Handle Grayscale images if image.size(0) == 1: @@ -213,8 +220,9 @@ def draw_bounding_boxes( if colors is None: colors = _generate_random_color_palette(len(img_boxes)) - label_color_map = dict(zip(labels, colors)) - colors = [label_color_map[label] for label in labels] + if labels is not None: + label_color_map = dict(zip(labels, colors)) + colors = [label_color_map[label] for label in labels] for i, bbox in enumerate(img_boxes): if isinstance(colors, list): @@ -223,9 +231,7 @@ def draw_bounding_boxes( color = colors if fill: - if color is None: - fill_color = (255, 255, 255, 100) - elif isinstance(color, str): + if isinstance(color, str): # This will automatically raise Error if rgb cannot be parsed. fill_color = ImageColor.getrgb(color) + (100,) elif isinstance(color, tuple): @@ -397,19 +403,23 @@ def _generate_color_palette(num_masks: int): def _generate_random_color() -> Tuple[int, int, int]: """ Generates a random RGB Color. + Returns: color Tuple(int, int, int): A Tuple containing Random RGB Color. """ return tuple(torch.randperm(256)[:3].tolist()) -def _generate_random_color_palette(num_colors) -> List[Tuple[int, int, int]]: +def _generate_random_color_palette(num_colors: int) -> List[Tuple[int, int, int]]: """ + Generates a random RGB Color palette. + Args: num_colors (int): Integer denoting number of random RGB colors to generate Returns: - colors(List[color]): A List each containing random RGB color. Each color is Tuple containing RGB values. + colors(List[Tuple[int, int, int]): A List of Tuples each containing random RGB color. + Each color is Tuple containing RGB values. """ return [_generate_random_color() for i in range(num_colors)] From 57fd9b07e00df7128ad4e64b20c2e29607b687c7 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Sun, 9 Jan 2022 21:53:25 +0530 Subject: [PATCH 06/14] Add tests --- test/test_utils.py | 13 +++++++++++++ torchvision/utils.py | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/test/test_utils.py b/test/test_utils.py index c671d2ee8bc..b9781a57ddb 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -325,5 +325,18 @@ def test_draw_keypoints_errors(): utils.draw_keypoints(image=img, keypoints=invalid_keypoints) +def test_random_colors(): + color_t = utils._generate_random_color() + assert isinstance(color_t, tuple) + assert len(color_t) == 3 + assert 256 not in color_t + + +def test_random_color_palette(): + color_plt = utils._generate_random_color_palette(5) + assert len(color_plt) == 5 + assert isinstance(color_plt, list) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/torchvision/utils.py b/torchvision/utils.py index 526a4b5c45a..488157a515d 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -407,7 +407,7 @@ def _generate_random_color() -> Tuple[int, int, int]: Returns: color Tuple(int, int, int): A Tuple containing Random RGB Color. """ - return tuple(torch.randperm(256)[:3].tolist()) + return tuple(torch.randint(256, (3,)).tolist()) def _generate_random_color_palette(num_colors: int) -> List[Tuple[int, int, int]]: From 8c4d70f0c3afa1c2349e4db603b992fcccc50de6 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Sun, 16 Jan 2022 23:45:51 +0530 Subject: [PATCH 07/14] Reuse palette --- test/test_utils.py | 13 ------------- torchvision/utils.py | 42 ++++++++++-------------------------------- 2 files changed, 10 insertions(+), 45 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index b9781a57ddb..c671d2ee8bc 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -325,18 +325,5 @@ def test_draw_keypoints_errors(): utils.draw_keypoints(image=img, keypoints=invalid_keypoints) -def test_random_colors(): - color_t = utils._generate_random_color() - assert isinstance(color_t, tuple) - assert len(color_t) == 3 - assert 256 not in color_t - - -def test_random_color_palette(): - color_plt = utils._generate_random_color_palette(5) - assert len(color_plt) == 5 - assert isinstance(color_plt, list) - - if __name__ == "__main__": pytest.main([__file__]) diff --git a/torchvision/utils.py b/torchvision/utils.py index 488157a515d..86b6f7d286b 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -169,8 +169,7 @@ def draw_bounding_boxes( colors (color or list of colors, optional): List containing the colors of the boxes or single color for all boxes. The color can be represented as PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. - By default, random colors are generated for boxes. - If labels are provided, boxes with same labels have same color. + By default, fixed colors are generated for boxes. fill (bool): If `True` fills the bounding box with specified color. width (int): Width of bounding box. font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may @@ -219,10 +218,7 @@ def draw_bounding_boxes( txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size) if colors is None: - colors = _generate_random_color_palette(len(img_boxes)) - if labels is not None: - label_color_map = dict(zip(labels, colors)) - colors = [label_color_map[label] for label in labels] + colors = _generate_color_palette(len(img_boxes)) for i, bbox in enumerate(img_boxes): if isinstance(colors, list): @@ -395,33 +391,15 @@ def draw_keypoints( return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) -def _generate_color_palette(num_masks: int): +def _generate_color_palette(num_masks: int, return_tensor: bool = True): palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) - return [tuple((i * palette) % 255) for i in range(num_masks)] - - -def _generate_random_color() -> Tuple[int, int, int]: - """ - Generates a random RGB Color. - - Returns: - color Tuple(int, int, int): A Tuple containing Random RGB Color. - """ - return tuple(torch.randint(256, (3,)).tolist()) - - -def _generate_random_color_palette(num_colors: int) -> List[Tuple[int, int, int]]: - """ - Generates a random RGB Color palette. - - Args: - num_colors (int): Integer denoting number of random RGB colors to generate - - Returns: - colors(List[Tuple[int, int, int]): A List of Tuples each containing random RGB color. - Each color is Tuple containing RGB values. - """ - return [_generate_random_color() for i in range(num_colors)] + clrs = [] + for i in range(num_masks): + if return_tensor: + clrs.append((tuple((palette * i) % 255))) + else: + clrs.append(tuple(((palette * i) % 255).tolist())) + return clrs def _log_api_usage_once(obj: Any) -> None: From 1f89e2e16507aff3bb3b9978ff7b7943e77a3bd7 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Sat, 22 Jan 2022 22:21:23 +0530 Subject: [PATCH 08/14] small rename fix --- torchvision/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index e2af7f985fc..6450774b714 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -395,10 +395,11 @@ def _generate_color_palette(num_masks: int, return_tensor: bool = True): palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) clrs = [] for i in range(num_masks): + clr = (palette * i) % 255 if return_tensor: - clrs.append((tuple((palette * i) % 255))) + clrs.append(tuple(clr)) else: - clrs.append(tuple(((palette * i) % 255).tolist())) + clrs.append(tuple((clr.tolist()))) return clrs From 587a8d4c8fdc163ed725f4edb666a5191f45ba84 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Thu, 27 Jan 2022 23:24:59 +0530 Subject: [PATCH 09/14] Update tests and code --- test/test_utils.py | 4 ++-- torchvision/utils.py | 27 +++++++++++---------------- 2 files changed, 13 insertions(+), 18 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 30e39c4e6b3..0b6b016d9e7 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -162,9 +162,9 @@ def test_draw_invalid_boxes(): utils.draw_bounding_boxes(img_wrong2, boxes) with pytest.raises(ValueError, match="Only grayscale and RGB images are supported"): utils.draw_bounding_boxes(img_wrong2[0][:2], boxes) - with pytest.raises(ValueError, match="Number of boxes and labels mismatch."): + with pytest.raises(ValueError, match="Number of boxes"): utils.draw_bounding_boxes(img_correct, boxes, labels_wrong) - with pytest.raises(ValueError, match="Number of colors should be greater or equal to the number of boxes."): + with pytest.raises(ValueError, match="Number of colors"): utils.draw_bounding_boxes(img_correct, boxes, colors=colors_wrong) diff --git a/torchvision/utils.py b/torchvision/utils.py index 562a344e5dc..e7e5024052c 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -201,13 +201,19 @@ def draw_bounding_boxes( if labels: if len(labels) != boxes.size(0): - raise ValueError("Number of boxes and labels mismatch. Please specify labels for each box.") + raise ValueError( + f"Number of boxes ({boxes.size(0)}) and labels ({len(labels)}) mismatch. Please specify labels for each box." + ) # List of colors should have minimum colors. if colors: if isinstance(colors, list): if len(colors) < boxes.size(0): - raise ValueError("Number of colors should be greater or equal to the number of boxes.") + raise ValueError( + f"Number of colors ({len(colors)}) are less than number of boxes ({boxes.size(0)}). Atleast supply colors for each box." + ) + else: + colors = [colors] * boxes.size(0) # Handle Grayscale images if image.size(0) == 1: @@ -228,11 +234,7 @@ def draw_bounding_boxes( colors = _generate_color_palette(len(img_boxes)) for i, bbox in enumerate(img_boxes): - if isinstance(colors, list): - color = colors[i] - else: - color = colors - + color = colors[i] if fill: if isinstance(color, str): # This will automatically raise Error if rgb cannot be parsed. @@ -505,16 +507,9 @@ def _make_colorwheel() -> torch.Tensor: return colorwheel -def _generate_color_palette(num_masks: int, return_tensor: bool = True): +def _generate_color_palette(num_objects: int): palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) - clrs = [] - for i in range(num_masks): - clr = (palette * i) % 255 - if return_tensor: - clrs.append(tuple(clr)) - else: - clrs.append(tuple((clr.tolist()))) - return clrs + return [tuple((i * palette) % 255) for i in range(num_objects)] def _log_api_usage_once(obj: Any) -> None: From 36c606dbe7d389a20ed2416731e9b0eaf053800d Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 2 Feb 2022 16:00:17 +0000 Subject: [PATCH 10/14] Simplify code --- torchvision/utils.py | 43 ++++++++++++++++++------------------------- 1 file changed, 18 insertions(+), 25 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index e7e5024052c..8629b24f76c 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -199,21 +199,22 @@ def draw_bounding_boxes( elif image.size(0) not in {1, 3}: raise ValueError("Only grayscale and RGB images are supported") - if labels: - if len(labels) != boxes.size(0): - raise ValueError( - f"Number of boxes ({boxes.size(0)}) and labels ({len(labels)}) mismatch. Please specify labels for each box." - ) - - # List of colors should have minimum colors. - if colors: - if isinstance(colors, list): - if len(colors) < boxes.size(0): - raise ValueError( - f"Number of colors ({len(colors)}) are less than number of boxes ({boxes.size(0)}). Atleast supply colors for each box." - ) - else: - colors = [colors] * boxes.size(0) + num_boxes = boxes.shape[0] + if labels and len(labels) != num_boxes: + raise ValueError( + f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. " + "Please specify labels for each box." + ) + + if colors is None: + colors = _generate_color_palette(num_boxes) + elif isinstance(colors, list): + if len(colors) < num_boxes: + raise ValueError(f"Number of colors ({len(colors)}) is less than number of boxes ({num_boxes}). ") + else: # colors specifies a single color for all boxes + colors = [colors] * num_boxes + + colors = [(ImageColor.getrgb(color) if isinstance(color, str) else color) for color in colors] # Handle Grayscale images if image.size(0) == 1: @@ -230,17 +231,9 @@ def draw_bounding_boxes( txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size) - if colors is None: - colors = _generate_color_palette(len(img_boxes)) - - for i, bbox in enumerate(img_boxes): - color = colors[i] + for i, (bbox, color) in enumerate(zip(img_boxes, colors)): if fill: - if isinstance(color, str): - # This will automatically raise Error if rgb cannot be parsed. - fill_color = ImageColor.getrgb(color) + (100,) - elif isinstance(color, tuple): - fill_color = color + (100,) + fill_color = color + (100,) draw.rectangle(bbox, width=width, outline=color, fill=fill_color) else: draw.rectangle(bbox, width=width, outline=color) From f9b4d2f1f368dfd2e78f1a144e45b97e91f5c991 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 2 Feb 2022 16:10:12 +0000 Subject: [PATCH 11/14] ufmt --- torchvision/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index 8629b24f76c..52d4789e494 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -202,8 +202,7 @@ def draw_bounding_boxes( num_boxes = boxes.shape[0] if labels and len(labels) != num_boxes: raise ValueError( - f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. " - "Please specify labels for each box." + f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box." ) if colors is None: From 4c95a5c495485d25bc12fa7f914ead45a0a7ae38 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 2 Feb 2022 16:20:14 +0000 Subject: [PATCH 12/14] fixed colors -> random colors in docstring --- torchvision/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index 20f1fbf78c6..019639d4264 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -176,7 +176,7 @@ def draw_bounding_boxes( colors (color or list of colors, optional): List containing the colors of the boxes or single color for all boxes. The color can be represented as PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. - By default, fixed colors are generated for boxes. + By default, random colors are generated for boxes. fill (bool): If `True` fills the bounding box with specified color. width (int): Width of bounding box. font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may From fcba00af80ad92c2b0aed752acb0549e07485334 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 2 Feb 2022 16:27:19 +0000 Subject: [PATCH 13/14] Actually simplify further --- torchvision/utils.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index 019639d4264..2c8c9a96516 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -200,7 +200,10 @@ def draw_bounding_boxes( raise ValueError("Only grayscale and RGB images are supported") num_boxes = boxes.shape[0] - if labels and len(labels) != num_boxes: + + if labels is None: + labels = [None] * num_boxes + elif len(labels) != num_boxes: raise ValueError( f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box." ) @@ -230,16 +233,16 @@ def draw_bounding_boxes( txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size) - for i, (bbox, color) in enumerate(zip(img_boxes, colors)): + for bbox, color, label in zip(img_boxes, colors, labels): if fill: fill_color = color + (100,) draw.rectangle(bbox, width=width, outline=color, fill=fill_color) else: draw.rectangle(bbox, width=width, outline=color) - if labels is not None: + if label is not None: margin = width + 1 - draw.text((bbox[0] + margin, bbox[1] + margin), labels[i], fill=color, font=txt_font) + draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font) return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) From 056a0333ac7594941745a9976a313596000c1309 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 3 Feb 2022 10:24:58 +0000 Subject: [PATCH 14/14] Silence mypy. Twice. lol. --- torchvision/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index 2c8c9a96516..34e36c553dd 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -202,7 +202,7 @@ def draw_bounding_boxes( num_boxes = boxes.shape[0] if labels is None: - labels = [None] * num_boxes + labels: Union[List[str], List[None]] = [None] * num_boxes # type: ignore[no-redef] elif len(labels) != num_boxes: raise ValueError( f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box." @@ -233,7 +233,7 @@ def draw_bounding_boxes( txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size) - for bbox, color, label in zip(img_boxes, colors, labels): + for bbox, color, label in zip(img_boxes, colors, labels): # type: ignore[arg-type] if fill: fill_color = color + (100,) draw.rectangle(bbox, width=width, outline=color, fill=fill_color)