From ba9db1275313b9cca23ab0025702329ac34a933a Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Sun, 31 Jan 2021 01:31:43 +0530 Subject: [PATCH 01/24] add draw segm masks --- docs/source/utils.rst | 4 +++- test/test_utils.py | 15 ++++++++++++ torchvision/utils.py | 55 ++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 72 insertions(+), 2 deletions(-) diff --git a/docs/source/utils.rst b/docs/source/utils.rst index 0ae450487e3..acaf785d817 100644 --- a/docs/source/utils.rst +++ b/docs/source/utils.rst @@ -7,4 +7,6 @@ torchvision.utils .. autofunction:: save_image -.. autofunction:: draw_bounding_boxes \ No newline at end of file +.. autofunction:: draw_bounding_boxes + +.. autofunction:: draw_segmentation_masks diff --git a/test/test_utils.py b/test/test_utils.py index 21e2ab461d7..565dab48aa3 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -96,6 +96,21 @@ def test_draw_boxes(self): expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1) self.assertTrue(torch.equal(result, expected)) + def test_draw_segmentation_masks(self): + img = torch.full((3, 20, 20), 255, dtype=torch.uint8) + colors = ["green", "#FF00FF", (0, 255, 0), "red"] + labels = ["a", "b", "c", "d"] + boxes = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18, 19, 20]]) + result = utils.draw_segmentation_masks(img, boxes, labels=labels, colors=colors) + + path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_segm_masks_util.png") + if not os.path.exists(path): + res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy()) + res.save(path) + + expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1) + self.assertTrue(torch.equal(result, expected)) + if __name__ == '__main__': unittest.main() diff --git a/torchvision/utils.py b/torchvision/utils.py index 6290809a7d6..a1871de82cb 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -7,7 +7,7 @@ from PIL import Image, ImageDraw from PIL import ImageFont -__all__ = ["make_grid", "save_image", "draw_bounding_boxes"] +__all__ = ["make_grid", "save_image", "draw_bounding_boxes", "draw_segmentation_masks"] @torch.no_grad() @@ -189,3 +189,56 @@ def draw_bounding_boxes( draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=txt_font) return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1) + + +@torch.no_grad() +def draw_segmentation_masks( + image: torch.Tensor, + masks: torch.Tensor, + labels: Optional[List[str]] = None, + colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None, + font: Optional[str] = None, + font_size: int = 10 +) -> torch.Tensor: + + """ + Draws segmentation masks on given image. + The values of the input image should be uint8 between 0 and 255. + + Args: + image (Tensor): Tensor of shape (C x H x W) + masks (Tensor): Tensor of shape (H, W). Each containing predicted class. + labels (List[str]): List containing the labels of masks. + colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of masks. The colors can + be represented as `str` or `Tuple[int, int, int]`. + font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may + also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`, + `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS. + font_size (int): The requested font size in points. + """ + + if not isinstance(image, torch.Tensor): + raise TypeError(f"Tensor expected, got {type(image)}") + elif image.dtype != torch.uint8: + raise ValueError(f"Tensor uint8 expected, got {image.dtype}") + elif image.dim() != 3: + raise ValueError("Pass individual images, not batches") + + ndarr = image.permute(1, 2, 0).numpy() + img_to_draw = Image.fromarray(ndarr) + + img_preds = masks.to(torch.int64).tolist() + + draw = ImageDraw.Draw(img_to_draw) + txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size) + + for i in range(len(img_preds)): + for j in range(len(img_preds)): + draw.point((i, j), fill=colors[img_preds[i][j]]) + + if labels is not None: + # Should we plot the text ? + # draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=txt_font) + pass + + return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1) From 9bb4191c5bcacb405200b775e4ff247a9bd51ad5 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Tue, 2 Feb 2021 21:21:42 +0530 Subject: [PATCH 02/24] rewrites with new api --- test/test_utils.py | 28 +++++++++++++++++++++++----- torchvision/utils.py | 38 ++++++++++++++++++-------------------- 2 files changed, 41 insertions(+), 25 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 9bd0c4621d0..233d8f707e2 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -96,14 +96,32 @@ def test_draw_boxes(self): expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1) self.assertTrue(torch.equal(result, expected)) - def test_draw_segmentation_masks(self): + def test_draw_segmentation_masks_colors(self): img = torch.full((3, 20, 20), 255, dtype=torch.uint8) colors = ["green", "#FF00FF", (0, 255, 0), "red"] - labels = ["a", "b", "c", "d"] - boxes = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18, 19, 20]]) - result = utils.draw_segmentation_masks(img, boxes, labels=labels, colors=colors) + # TODO + masks = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) + result = utils.draw_segmentation_masks(img, masks, colors=colors) + + path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", + "fakedata", "draw_segm_masks_colors_util.png") + + if not os.path.exists(path): + res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy()) + res.save(path) + + expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1) + self.assertTrue(torch.equal(result, expected)) + + def test_draw_segmentation_masks_no_colors(self): + img = torch.full((3, 20, 20), 255, dtype=torch.uint8) + # TODO + masks = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) + result = utils.draw_segmentation_masks(img, masks, colors=None) + + path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", + "fakedata", "draw_segm_masks_no_colors_util.png") - path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_segm_masks_util.png") if not os.path.exists(path): res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy()) res.save(path) diff --git a/torchvision/utils.py b/torchvision/utils.py index a7dd687668e..16ceaa91b00 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -216,10 +216,7 @@ def draw_bounding_boxes( def draw_segmentation_masks( image: torch.Tensor, masks: torch.Tensor, - labels: Optional[List[str]] = None, colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None, - font: Optional[str] = None, - font_size: int = 10 ) -> torch.Tensor: """ @@ -232,10 +229,6 @@ def draw_segmentation_masks( labels (List[str]): List containing the labels of masks. colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of masks. The colors can be represented as `str` or `Tuple[int, int, int]`. - font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may - also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`, - `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS. - font_size (int): The requested font size in points. """ if not isinstance(image, torch.Tensor): @@ -245,21 +238,26 @@ def draw_segmentation_masks( elif image.dim() != 3: raise ValueError("Pass individual images, not batches") - ndarr = image.permute(1, 2, 0).numpy() - img_to_draw = Image.fromarray(ndarr) + img_to_draw = Image.fromarray(masks.byte().cpu().numpy()).resize(image.size) - img_preds = masks.to(torch.int64).tolist() - - draw = ImageDraw.Draw(img_to_draw) - txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size) + if colors is None: + palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) + colors = torch.as_tensor([i for i in range(21)])[:, None] * palette + color_arr = (colors % 255).numpy().astype("uint8") - for i in range(len(img_preds)): - for j in range(len(img_preds)): - draw.point((i, j), fill=colors[img_preds[i][j]]) + else: + color_list = [] + for color in colors: + if isinstance(color, str): + # This will automatically raise Error if rgb cannot be parsed. + fill_color = ImageColor.getrgb(color) # + (100,) + color_list.append(fill_color) + elif isinstance(color, tuple): + # fill_color = color + (100,) + # Use the given colors list and create ndarray of colors. + color_list.append(color) - if labels is not None: - # Should we plot the text ? - # draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=txt_font) - pass + color_arr = np.array(color_list).astype("uint8") + img_to_draw.putpalette(color_arr) return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1) From d3e28b78ee3b5087bafa33e7fd9c5f14a5dddbde Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Tue, 2 Feb 2021 21:39:01 +0530 Subject: [PATCH 03/24] fix flaky colors --- torchvision/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index 16ceaa91b00..1a5a3a26013 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -242,8 +242,8 @@ def draw_segmentation_masks( if colors is None: palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) - colors = torch.as_tensor([i for i in range(21)])[:, None] * palette - color_arr = (colors % 255).numpy().astype("uint8") + colors_t = torch.as_tensor([i for i in range(21)])[:, None] * palette + color_arr = (colors_t % 255).numpy().astype("uint8") else: color_list = [] From 313f6a6ea75cc1a8e5ea2b3010e80d700513ac3b Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Tue, 2 Feb 2021 22:42:22 +0530 Subject: [PATCH 04/24] fix resize bug --- torchvision/utils.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index 1a5a3a26013..81fdaca8206 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -238,7 +238,9 @@ def draw_segmentation_masks( elif image.dim() != 3: raise ValueError("Pass individual images, not batches") - img_to_draw = Image.fromarray(masks.byte().cpu().numpy()).resize(image.size) + img_to_draw = Image.fromarray(masks.byte().cpu().numpy()) + # ndarr = image.permute(1, 2, 0).numpy() + # img_to_draw = Image.fromarray(ndarr) if colors is None: palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) @@ -258,6 +260,11 @@ def draw_segmentation_masks( color_list.append(color) color_arr = np.array(color_list).astype("uint8") + # print(color_list) + # print(color_arr) img_to_draw.putpalette(color_arr) - return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1) + # print(img_to_draw) + # print(color_arr) + # return torch.from_numpy(np.array(img_to_draw)) + return torch.from_numpy(np.array(img_to_draw)).unsqueeze_(0).repeat(3, 1, 1) From 80e4e4ad38320f05daa01cc55905a0f5afab433c Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Tue, 2 Feb 2021 22:50:11 +0530 Subject: [PATCH 05/24] resize for sanity --- torchvision/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index 81fdaca8206..ca7c431c44d 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -238,7 +238,7 @@ def draw_segmentation_masks( elif image.dim() != 3: raise ValueError("Pass individual images, not batches") - img_to_draw = Image.fromarray(masks.byte().cpu().numpy()) + img_to_draw = Image.fromarray(masks.byte().cpu().numpy()).resize(image.size()[1:]) # ndarr = image.permute(1, 2, 0).numpy() # img_to_draw = Image.fromarray(ndarr) From 0388c7accd8c78c0524e6933b630ea44ea05fbe0 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Tue, 2 Feb 2021 22:53:42 +0530 Subject: [PATCH 06/24] cleanup --- torchvision/utils.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index ca7c431c44d..622c98aa558 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -239,8 +239,6 @@ def draw_segmentation_masks( raise ValueError("Pass individual images, not batches") img_to_draw = Image.fromarray(masks.byte().cpu().numpy()).resize(image.size()[1:]) - # ndarr = image.permute(1, 2, 0).numpy() - # img_to_draw = Image.fromarray(ndarr) if colors is None: palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) @@ -260,11 +258,6 @@ def draw_segmentation_masks( color_list.append(color) color_arr = np.array(color_list).astype("uint8") - # print(color_list) - # print(color_arr) img_to_draw.putpalette(color_arr) - # print(img_to_draw) - # print(color_arr) - # return torch.from_numpy(np.array(img_to_draw)) return torch.from_numpy(np.array(img_to_draw)).unsqueeze_(0).repeat(3, 1, 1) From d6018e8e8f5edbc93721c012441a80294f84ae25 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Tue, 2 Feb 2021 23:09:34 +0530 Subject: [PATCH 07/24] project the image --- torchvision/utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index 622c98aa558..492c019eab5 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -241,6 +241,7 @@ def draw_segmentation_masks( img_to_draw = Image.fromarray(masks.byte().cpu().numpy()).resize(image.size()[1:]) if colors is None: + # Default color palette assumes 21 classes. palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) colors_t = torch.as_tensor([i for i in range(21)])[:, None] * palette color_arr = (colors_t % 255).numpy().astype("uint8") @@ -260,4 +261,8 @@ def draw_segmentation_masks( color_arr = np.array(color_list).astype("uint8") img_to_draw.putpalette(color_arr) - return torch.from_numpy(np.array(img_to_draw)).unsqueeze_(0).repeat(3, 1, 1) + img_to_draw = torch.from_numpy(np.array(img_to_draw)) + + # Project the drawn image to orignal one + image[: 1] = img_to_draw + return image From 4af55496c8dbb675c192fe694218c809b60a2908 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Thu, 4 Feb 2021 22:27:51 +0530 Subject: [PATCH 08/24] Minor refactor to adopt num classes --- torchvision/utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index 492c019eab5..beba41b525e 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -225,7 +225,7 @@ def draw_segmentation_masks( Args: image (Tensor): Tensor of shape (C x H x W) - masks (Tensor): Tensor of shape (H, W). Each containing predicted class. + masks (Tensor): Tensor of shape (num_classes, H, W). Each containing probability of predicted class. labels (List[str]): List containing the labels of masks. colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of masks. The colors can be represented as `str` or `Tuple[int, int, int]`. @@ -238,12 +238,14 @@ def draw_segmentation_masks( elif image.dim() != 3: raise ValueError("Pass individual images, not batches") + num_classes = masks.size()[0] + masks = masks.argmax(0) + img_to_draw = Image.fromarray(masks.byte().cpu().numpy()).resize(image.size()[1:]) if colors is None: - # Default color palette assumes 21 classes. palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) - colors_t = torch.as_tensor([i for i in range(21)])[:, None] * palette + colors_t = torch.as_tensor([i for i in range(num_classes)])[:, None] * palette color_arr = (colors_t % 255).numpy().astype("uint8") else: From f11fa6136336c73fd3d51a45edadad3df4f2b6b7 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Thu, 4 Feb 2021 23:14:08 +0530 Subject: [PATCH 09/24] add uint8 in docstring --- torchvision/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index beba41b525e..7113b1460de 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -153,7 +153,7 @@ def draw_bounding_boxes( If filled, Resulting Tensor should be saved as PNG image. Args: - image (Tensor): Tensor of shape (C x H x W) + image (Tensor): Tensor of shape (C x H x W) and dtype uint8. boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and `0 <= ymin < ymax < H`. @@ -220,7 +220,7 @@ def draw_segmentation_masks( ) -> torch.Tensor: """ - Draws segmentation masks on given image. + Draws segmentation masks on given image and dtype uint8. The values of the input image should be uint8 between 0 and 255. Args: From d554c7594147ae0a2643f71e94d0b541a3d05778 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Sat, 6 Feb 2021 21:42:43 +0530 Subject: [PATCH 10/24] adds alpha and docstring --- torchvision/utils.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index 7113b1460de..a99f1eb813d 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -216,6 +216,7 @@ def draw_bounding_boxes( def draw_segmentation_masks( image: torch.Tensor, masks: torch.Tensor, + alpha: Optional[int] = 0.2, colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None, ) -> torch.Tensor: @@ -226,7 +227,7 @@ def draw_segmentation_masks( Args: image (Tensor): Tensor of shape (C x H x W) masks (Tensor): Tensor of shape (num_classes, H, W). Each containing probability of predicted class. - labels (List[str]): List containing the labels of masks. + alpha (int): Integer denoting factor of transpaerency of masks. colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of masks. The colors can be represented as `str` or `Tuple[int, int, int]`. """ @@ -247,7 +248,7 @@ def draw_segmentation_masks( palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) colors_t = torch.as_tensor([i for i in range(num_classes)])[:, None] * palette color_arr = (colors_t % 255).numpy().astype("uint8") - + color_arr[1:, 3] = 255 else: color_list = [] for color in colors: @@ -263,8 +264,9 @@ def draw_segmentation_masks( color_arr = np.array(color_list).astype("uint8") img_to_draw.putpalette(color_arr) - img_to_draw = torch.from_numpy(np.array(img_to_draw)) - # Project the drawn image to orignal one - image[: 1] = img_to_draw - return image + img_to_draw = torch.from_numpy(np.array(img_to_draw.convert('RGBA'))) + img_to_draw = img_to_draw.permute((2, 1, 0)) + + return (torch.cat([image, torch.full(image.shape[1:], 255).unsqueeze(0)]).float() + * alpha + img_to_draw.float() * (1.0 - alpha)).to(dtype=torch.uint8) From 301b9de12e558c8aaf5285c48c8902f87e587e73 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Sat, 6 Feb 2021 21:55:59 +0530 Subject: [PATCH 11/24] move code a bit down --- torchvision/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index a99f1eb813d..9998df94ebd 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -242,8 +242,6 @@ def draw_segmentation_masks( num_classes = masks.size()[0] masks = masks.argmax(0) - img_to_draw = Image.fromarray(masks.byte().cpu().numpy()).resize(image.size()[1:]) - if colors is None: palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) colors_t = torch.as_tensor([i for i in range(num_classes)])[:, None] * palette @@ -263,6 +261,7 @@ def draw_segmentation_masks( color_arr = np.array(color_list).astype("uint8") + img_to_draw = Image.fromarray(masks.byte().cpu().numpy()).resize(image.size()[1:]) img_to_draw.putpalette(color_arr) img_to_draw = torch.from_numpy(np.array(img_to_draw.convert('RGBA'))) From 02226fab05d3ea77a4e89ddc15e49a407e74682e Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Mon, 8 Feb 2021 20:12:42 +0530 Subject: [PATCH 12/24] Minor fix --- torchvision/utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index 9998df94ebd..4cfc914b5de 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -221,11 +221,11 @@ def draw_segmentation_masks( ) -> torch.Tensor: """ - Draws segmentation masks on given image and dtype uint8. + Draws segmentation masks on given image. The values of the input image should be uint8 between 0 and 255. Args: - image (Tensor): Tensor of shape (C x H x W) + image (Tensor): Tensor of shape (C x H x W) and dtype uint8. masks (Tensor): Tensor of shape (num_classes, H, W). Each containing probability of predicted class. alpha (int): Integer denoting factor of transpaerency of masks. colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of masks. The colors can @@ -255,7 +255,6 @@ def draw_segmentation_masks( fill_color = ImageColor.getrgb(color) # + (100,) color_list.append(fill_color) elif isinstance(color, tuple): - # fill_color = color + (100,) # Use the given colors list and create ndarray of colors. color_list.append(color) From a720f9113bb8e0a057d8036cfc1ba7a44ba9c8da Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Mon, 8 Feb 2021 22:23:11 +0530 Subject: [PATCH 13/24] fix type check --- torchvision/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index 4cfc914b5de..b05e648566e 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -216,7 +216,7 @@ def draw_bounding_boxes( def draw_segmentation_masks( image: torch.Tensor, masks: torch.Tensor, - alpha: Optional[int] = 0.2, + alpha: Optional[float] = 0.2, colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None, ) -> torch.Tensor: @@ -227,7 +227,7 @@ def draw_segmentation_masks( Args: image (Tensor): Tensor of shape (C x H x W) and dtype uint8. masks (Tensor): Tensor of shape (num_classes, H, W). Each containing probability of predicted class. - alpha (int): Integer denoting factor of transpaerency of masks. + alpha (float): Float number between 0 and 1 denoting factor of transpaerency of masks. colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of masks. The colors can be represented as `str` or `Tuple[int, int, int]`. """ From 155c568a418cae9683fc2e9394add9816c5771f6 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 8 Feb 2021 18:00:29 +0000 Subject: [PATCH 14/24] Fixing resize bug. --- torchvision/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index b05e648566e..c2635166d58 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -246,7 +246,6 @@ def draw_segmentation_masks( palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) colors_t = torch.as_tensor([i for i in range(num_classes)])[:, None] * palette color_arr = (colors_t % 255).numpy().astype("uint8") - color_arr[1:, 3] = 255 else: color_list = [] for color in colors: @@ -260,11 +259,12 @@ def draw_segmentation_masks( color_arr = np.array(color_list).astype("uint8") - img_to_draw = Image.fromarray(masks.byte().cpu().numpy()).resize(image.size()[1:]) + _, h, w = image.size() + img_to_draw = Image.fromarray(masks.byte().cpu().numpy()).resize((w, h)) img_to_draw.putpalette(color_arr) img_to_draw = torch.from_numpy(np.array(img_to_draw.convert('RGBA'))) - img_to_draw = img_to_draw.permute((2, 1, 0)) + img_to_draw = img_to_draw.permute((2, 0, 1)) return (torch.cat([image, torch.full(image.shape[1:], 255).unsqueeze(0)]).float() * alpha + img_to_draw.float() * (1.0 - alpha)).to(dtype=torch.uint8) From 8e6df85c37edb1b58ceb72587583c35769b90c8b Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 8 Feb 2021 18:13:03 +0000 Subject: [PATCH 15/24] Fix type of alpha. --- torchvision/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index c2635166d58..e8743219714 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -216,7 +216,7 @@ def draw_bounding_boxes( def draw_segmentation_masks( image: torch.Tensor, masks: torch.Tensor, - alpha: Optional[float] = 0.2, + alpha: float = 0.2, colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None, ) -> torch.Tensor: From ed5de580fa2db3318ced58ce1cc9598b6e15396a Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 8 Feb 2021 18:34:01 +0000 Subject: [PATCH 16/24] Remove unnecessary RGBA conversions. --- torchvision/utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index e8743219714..873ae3754c5 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -263,8 +263,7 @@ def draw_segmentation_masks( img_to_draw = Image.fromarray(masks.byte().cpu().numpy()).resize((w, h)) img_to_draw.putpalette(color_arr) - img_to_draw = torch.from_numpy(np.array(img_to_draw.convert('RGBA'))) + img_to_draw = torch.from_numpy(np.array(img_to_draw.convert('RGB'))) img_to_draw = img_to_draw.permute((2, 0, 1)) - return (torch.cat([image, torch.full(image.shape[1:], 255).unsqueeze(0)]).float() - * alpha + img_to_draw.float() * (1.0 - alpha)).to(dtype=torch.uint8) + return (image.float() * alpha + img_to_draw.float() * (1.0 - alpha)).to(dtype=torch.uint8) From 731ab330ac171ce82775c6811d3f4d475457cf09 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Tue, 9 Feb 2021 17:34:29 +0530 Subject: [PATCH 17/24] update docs to supported only rgb --- torchvision/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index 873ae3754c5..1f79ba95ce2 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -221,11 +221,11 @@ def draw_segmentation_masks( ) -> torch.Tensor: """ - Draws segmentation masks on given image. + Draws segmentation masks on given RGB image. The values of the input image should be uint8 between 0 and 255. Args: - image (Tensor): Tensor of shape (C x H x W) and dtype uint8. + image (Tensor): Tensor of shape (3 x H x W) and dtype uint8. masks (Tensor): Tensor of shape (num_classes, H, W). Each containing probability of predicted class. alpha (float): Float number between 0 and 1 denoting factor of transpaerency of masks. colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of masks. The colors can @@ -238,6 +238,8 @@ def draw_segmentation_masks( raise ValueError(f"Tensor uint8 expected, got {image.dtype}") elif image.dim() != 3: raise ValueError("Pass individual images, not batches") + elif image.size()[0] != 3: + raise ValueError("Pass an RGB image. Other Image formats are not supported") num_classes = masks.size()[0] masks = masks.argmax(0) From e582a7473c56da02f254ba4fcb39119fd2a4e5e3 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Tue, 9 Feb 2021 18:33:44 +0530 Subject: [PATCH 18/24] minor edits --- torchvision/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index 1f79ba95ce2..7ac461aa303 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -253,10 +253,9 @@ def draw_segmentation_masks( for color in colors: if isinstance(color, str): # This will automatically raise Error if rgb cannot be parsed. - fill_color = ImageColor.getrgb(color) # + (100,) + fill_color = ImageColor.getrgb(color) color_list.append(fill_color) elif isinstance(color, tuple): - # Use the given colors list and create ndarray of colors. color_list.append(color) color_arr = np.array(color_list).astype("uint8") From fd6111839eeb49168d710398782e8dba1745e5f7 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Tue, 9 Feb 2021 18:34:48 +0530 Subject: [PATCH 19/24] adds tests --- .../fakedata/draw_segm_masks_colors_util.png | Bin 0 -> 88 bytes .../draw_segm_masks_no_colors_util.png | Bin 0 -> 106 bytes test/test_utils.py | 58 ++++++++++++++++-- 3 files changed, 52 insertions(+), 6 deletions(-) create mode 100644 test/assets/fakedata/draw_segm_masks_colors_util.png create mode 100644 test/assets/fakedata/draw_segm_masks_no_colors_util.png diff --git a/test/assets/fakedata/draw_segm_masks_colors_util.png b/test/assets/fakedata/draw_segm_masks_colors_util.png new file mode 100644 index 0000000000000000000000000000000000000000..454b35556317dc1da1707fb234cf8563c1e8c707 GIT binary patch literal 88 zcmeAS@N?(olHy`uVBq!ia0vp^tRT$61SFYwH*Nw_@}4e^Ar*6yP5$MdX<(7~Fa6*B l;o_6Vh6|XE{XeGhiNULzE&Y{;O%+fngQu&X%Q~loCIFN+8JPe8 literal 0 HcmV?d00001 diff --git a/test/assets/fakedata/draw_segm_masks_no_colors_util.png b/test/assets/fakedata/draw_segm_masks_no_colors_util.png new file mode 100644 index 0000000000000000000000000000000000000000..f048d2469d2414d6e1e864111a6117a30a7d210b GIT binary patch literal 106 zcmeAS@N?(olHy`uVBq!ia0vp^A|TAc1SFYWcSQjyLr)jSkcv6UCi5Pib Date: Tue, 9 Feb 2021 19:00:36 +0530 Subject: [PATCH 20/24] shifts masks up --- test/test_utils.py | 74 +++++++++++++++------------------------------- 1 file changed, 24 insertions(+), 50 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 52452e7b4a5..3ebd65a51b5 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -9,6 +9,30 @@ import torchvision.transforms.functional as F from PIL import Image +masks = torch.tensor([ + [ + [-2.2799, -2.2799, -2.2799, -2.2799, -2.2799], + [5.0914, 5.0914, 5.0914, 5.0914, 5.0914], + [-2.2799, -2.2799, -2.2799, -2.2799, -2.2799], + [-2.2799, -2.2799, -2.2799, -2.2799, -2.2799], + [-2.2799, -2.2799, -2.2799, -2.2799, -2.2799] + ], + [ + [5.0914, 5.0914, 5.0914, 5.0914, 5.0914], + [-2.2799, -2.2799, -2.2799, -2.2799, -2.2799], + [5.0914, 5.0914, 5.0914, 5.0914, 5.0914], + [5.0914, 5.0914, 5.0914, 5.0914, 5.0914], + [-1.4541, -1.4541, -1.4541, -1.4541, -1.4541] + ], + [ + [-1.4541, -1.4541, -1.4541, -1.4541, -1.4541], + [-1.4541, -1.4541, -1.4541, -1.4541, -1.4541], + [-1.4541, -1.4541, -1.4541, -1.4541, -1.4541], + [-1.4541, -1.4541, -1.4541, -1.4541, -1.4541], + [5.0914, 5.0914, 5.0914, 5.0914, 5.0914], + ] +], dtype=torch.float) + class Tester(unittest.TestCase): @@ -99,31 +123,6 @@ def test_draw_boxes(self): def test_draw_segmentation_masks_colors(self): img = torch.full((3, 5, 5), 255, dtype=torch.uint8) colors = ["#FF00FF", (0, 255, 0), "red"] - masks = torch.tensor( - [ - [ - [-2.2799, -2.2799, -2.2799, -2.2799, -2.2799], - [5.0914, 5.0914, 5.0914, 5.0914, 5.0914], - [-2.2799, -2.2799, -2.2799, -2.2799, -2.2799], - [-2.2799, -2.2799, -2.2799, -2.2799, -2.2799], - [-2.2799, -2.2799, -2.2799, -2.2799, -2.2799] - ], - - [ - [5.0914, 5.0914, 5.0914, 5.0914, 5.0914], - [-2.2799, -2.2799, -2.2799, -2.2799, -2.2799], - [5.0914, 5.0914, 5.0914, 5.0914, 5.0914], - [5.0914, 5.0914, 5.0914, 5.0914, 5.0914], - [-1.4541, -1.4541, -1.4541, -1.4541, -1.4541] - ], - [ - [-1.4541, -1.4541, -1.4541, -1.4541, -1.4541], - [-1.4541, -1.4541, -1.4541, -1.4541, -1.4541], - [-1.4541, -1.4541, -1.4541, -1.4541, -1.4541], - [-1.4541, -1.4541, -1.4541, -1.4541, -1.4541], - [5.0914, 5.0914, 5.0914, 5.0914, 5.0914], - ] - ], dtype=torch.float) result = utils.draw_segmentation_masks(img, masks, colors=colors) path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", @@ -138,31 +137,6 @@ def test_draw_segmentation_masks_colors(self): def test_draw_segmentation_masks_no_colors(self): img = torch.full((3, 20, 20), 255, dtype=torch.uint8) - masks = torch.tensor( - [ - [ - [-2.2799, -2.2799, -2.2799, -2.2799, -2.2799], - [5.0914, 5.0914, 5.0914, 5.0914, 5.0914], - [-2.2799, -2.2799, -2.2799, -2.2799, -2.2799], - [-2.2799, -2.2799, -2.2799, -2.2799, -2.2799], - [-2.2799, -2.2799, -2.2799, -2.2799, -2.2799] - ], - - [ - [5.0914, 5.0914, 5.0914, 5.0914, 5.0914], - [-2.2799, -2.2799, -2.2799, -2.2799, -2.2799], - [5.0914, 5.0914, 5.0914, 5.0914, 5.0914], - [5.0914, 5.0914, 5.0914, 5.0914, 5.0914], - [-1.4541, -1.4541, -1.4541, -1.4541, -1.4541] - ], - [ - [-1.4541, -1.4541, -1.4541, -1.4541, -1.4541], - [-1.4541, -1.4541, -1.4541, -1.4541, -1.4541], - [-1.4541, -1.4541, -1.4541, -1.4541, -1.4541], - [-1.4541, -1.4541, -1.4541, -1.4541, -1.4541], - [5.0914, 5.0914, 5.0914, 5.0914, 5.0914], - ] - ], dtype=torch.float) result = utils.draw_segmentation_masks(img, masks, colors=None) path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", From 58b3870a03e0cff58bf59fc0333361f3697fa7cd Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Wed, 10 Feb 2021 22:18:30 +0530 Subject: [PATCH 21/24] change tests and impelementation for bool --- test/test_utils.py | 34 +++++++++++++++++----------------- torchvision/utils.py | 8 ++++---- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 3ebd65a51b5..b7cb273bff1 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -11,27 +11,27 @@ masks = torch.tensor([ [ - [-2.2799, -2.2799, -2.2799, -2.2799, -2.2799], - [5.0914, 5.0914, 5.0914, 5.0914, 5.0914], - [-2.2799, -2.2799, -2.2799, -2.2799, -2.2799], - [-2.2799, -2.2799, -2.2799, -2.2799, -2.2799], - [-2.2799, -2.2799, -2.2799, -2.2799, -2.2799] + [False, False, False, False, False], + [True, True, True, True, True], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False] ], [ - [5.0914, 5.0914, 5.0914, 5.0914, 5.0914], - [-2.2799, -2.2799, -2.2799, -2.2799, -2.2799], - [5.0914, 5.0914, 5.0914, 5.0914, 5.0914], - [5.0914, 5.0914, 5.0914, 5.0914, 5.0914], - [-1.4541, -1.4541, -1.4541, -1.4541, -1.4541] + [True, True, True, True, True], + [False, False, False, False, False], + [True, True, True, True, True], + [True, True, True, True, True], + [False, False, False, False, False] ], [ - [-1.4541, -1.4541, -1.4541, -1.4541, -1.4541], - [-1.4541, -1.4541, -1.4541, -1.4541, -1.4541], - [-1.4541, -1.4541, -1.4541, -1.4541, -1.4541], - [-1.4541, -1.4541, -1.4541, -1.4541, -1.4541], - [5.0914, 5.0914, 5.0914, 5.0914, 5.0914], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [True, True, True, True, True], ] -], dtype=torch.float) +], dtype=torch.bool) class Tester(unittest.TestCase): @@ -136,7 +136,7 @@ def test_draw_segmentation_masks_colors(self): self.assertTrue(torch.equal(result, expected)) def test_draw_segmentation_masks_no_colors(self): - img = torch.full((3, 20, 20), 255, dtype=torch.uint8) + img = torch.full((3, 5, 5), 255, dtype=torch.uint8) result = utils.draw_segmentation_masks(img, masks, colors=None) path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", diff --git a/torchvision/utils.py b/torchvision/utils.py index 7ac461aa303..b9857c9f7a2 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -226,7 +226,7 @@ def draw_segmentation_masks( Args: image (Tensor): Tensor of shape (3 x H x W) and dtype uint8. - masks (Tensor): Tensor of shape (num_classes, H, W). Each containing probability of predicted class. + masks (Tensor): Boolean Tensor of shape (num_masks, H, W). alpha (float): Float number between 0 and 1 denoting factor of transpaerency of masks. colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of masks. The colors can be represented as `str` or `Tuple[int, int, int]`. @@ -241,12 +241,12 @@ def draw_segmentation_masks( elif image.size()[0] != 3: raise ValueError("Pass an RGB image. Other Image formats are not supported") - num_classes = masks.size()[0] + num_masks = masks.size()[0] masks = masks.argmax(0) if colors is None: palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) - colors_t = torch.as_tensor([i for i in range(num_classes)])[:, None] * palette + colors_t = torch.as_tensor([i for i in range(num_masks)])[:, None] * palette color_arr = (colors_t % 255).numpy().astype("uint8") else: color_list = [] @@ -261,7 +261,7 @@ def draw_segmentation_masks( color_arr = np.array(color_list).astype("uint8") _, h, w = image.size() - img_to_draw = Image.fromarray(masks.byte().cpu().numpy()).resize((w, h)) + img_to_draw = Image.fromarray(masks.cpu().numpy(), mode="1").resize((w, h)) img_to_draw.putpalette(color_arr) img_to_draw = torch.from_numpy(np.array(img_to_draw.convert('RGB'))) From fbf4dc78229f951d2fede821fedc57913f051ab8 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Wed, 10 Feb 2021 22:40:41 +0530 Subject: [PATCH 22/24] change mode to L --- torchvision/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index b9857c9f7a2..4d450c103b9 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -261,7 +261,7 @@ def draw_segmentation_masks( color_arr = np.array(color_list).astype("uint8") _, h, w = image.size() - img_to_draw = Image.fromarray(masks.cpu().numpy(), mode="1").resize((w, h)) + img_to_draw = Image.fromarray(masks.cpu().numpy(), mode="L").resize((w, h)) img_to_draw.putpalette(color_arr) img_to_draw = torch.from_numpy(np.array(img_to_draw.convert('RGB'))) From bc81e50a888f5785125483da4e04d05fbdae1dec Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Fri, 19 Mar 2021 22:17:10 +0530 Subject: [PATCH 23/24] convert to float --- test/test_utils.py | 34 +++++++++++++++++----------------- torchvision/utils.py | 5 ++--- 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 030bc985cc4..fcf05edd11a 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -11,27 +11,27 @@ masks = torch.tensor([ [ - [False, False, False, False, False], - [True, True, True, True, True], - [False, False, False, False, False], - [False, False, False, False, False], - [False, False, False, False, False] + [-2.2799, -2.2799, -2.2799, -2.2799, -2.2799], + [5.0914, 5.0914, 5.0914, 5.0914, 5.0914], + [-2.2799, -2.2799, -2.2799, -2.2799, -2.2799], + [-2.2799, -2.2799, -2.2799, -2.2799, -2.2799], + [-2.2799, -2.2799, -2.2799, -2.2799, -2.2799] ], [ - [True, True, True, True, True], - [False, False, False, False, False], - [True, True, True, True, True], - [True, True, True, True, True], - [False, False, False, False, False] + [5.0914, 5.0914, 5.0914, 5.0914, 5.0914], + [-2.2799, -2.2799, -2.2799, -2.2799, -2.2799], + [5.0914, 5.0914, 5.0914, 5.0914, 5.0914], + [5.0914, 5.0914, 5.0914, 5.0914, 5.0914], + [-1.4541, -1.4541, -1.4541, -1.4541, -1.4541] ], [ - [False, False, False, False, False], - [False, False, False, False, False], - [False, False, False, False, False], - [False, False, False, False, False], - [True, True, True, True, True], + [-1.4541, -1.4541, -1.4541, -1.4541, -1.4541], + [-1.4541, -1.4541, -1.4541, -1.4541, -1.4541], + [-1.4541, -1.4541, -1.4541, -1.4541, -1.4541], + [-1.4541, -1.4541, -1.4541, -1.4541, -1.4541], + [5.0914, 5.0914, 5.0914, 5.0914, 5.0914], ] -], dtype=torch.bool) +], dtype=torch.float) class Tester(unittest.TestCase): @@ -136,7 +136,7 @@ def test_draw_segmentation_masks_colors(self): self.assertTrue(torch.equal(result, expected)) def test_draw_segmentation_masks_no_colors(self): - img = torch.full((3, 5, 5), 255, dtype=torch.uint8) + img = torch.full((3, 20, 20), 255, dtype=torch.uint8) result = utils.draw_segmentation_masks(img, masks, colors=None) path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", diff --git a/torchvision/utils.py b/torchvision/utils.py index 4d450c103b9..ea75c8ce096 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -223,10 +223,9 @@ def draw_segmentation_masks( """ Draws segmentation masks on given RGB image. The values of the input image should be uint8 between 0 and 255. - Args: image (Tensor): Tensor of shape (3 x H x W) and dtype uint8. - masks (Tensor): Boolean Tensor of shape (num_masks, H, W). + masks (Tensor): Tensor of shape (num_masks, H, W). Each containing probability of predicted class. alpha (float): Float number between 0 and 1 denoting factor of transpaerency of masks. colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of masks. The colors can be represented as `str` or `Tuple[int, int, int]`. @@ -261,7 +260,7 @@ def draw_segmentation_masks( color_arr = np.array(color_list).astype("uint8") _, h, w = image.size() - img_to_draw = Image.fromarray(masks.cpu().numpy(), mode="L").resize((w, h)) + img_to_draw = Image.fromarray(masks.byte().cpu().numpy()).resize((w, h)) img_to_draw.putpalette(color_arr) img_to_draw = torch.from_numpy(np.array(img_to_draw.convert('RGB'))) From 32de89bc398619b62de9e5a733d218072990f95a Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Sat, 20 Mar 2021 11:29:31 +0530 Subject: [PATCH 24/24] fixes docs --- torchvision/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/utils.py b/torchvision/utils.py index ea75c8ce096..54cf4c3e4c2 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -223,6 +223,7 @@ def draw_segmentation_masks( """ Draws segmentation masks on given RGB image. The values of the input image should be uint8 between 0 and 255. + Args: image (Tensor): Tensor of shape (3 x H x W) and dtype uint8. masks (Tensor): Tensor of shape (num_masks, H, W). Each containing probability of predicted class.