diff --git a/docs/source/utils.rst b/docs/source/utils.rst index 0ae450487e3..acaf785d817 100644 --- a/docs/source/utils.rst +++ b/docs/source/utils.rst @@ -7,4 +7,6 @@ torchvision.utils .. autofunction:: save_image -.. autofunction:: draw_bounding_boxes \ No newline at end of file +.. autofunction:: draw_bounding_boxes + +.. autofunction:: draw_segmentation_masks diff --git a/test/assets/fakedata/draw_segm_masks_colors_util.png b/test/assets/fakedata/draw_segm_masks_colors_util.png new file mode 100644 index 00000000000..454b3555631 Binary files /dev/null and b/test/assets/fakedata/draw_segm_masks_colors_util.png differ diff --git a/test/assets/fakedata/draw_segm_masks_no_colors_util.png b/test/assets/fakedata/draw_segm_masks_no_colors_util.png new file mode 100644 index 00000000000..f048d2469d2 Binary files /dev/null and b/test/assets/fakedata/draw_segm_masks_no_colors_util.png differ diff --git a/test/test_utils.py b/test/test_utils.py index 1fcee7ce489..fcf05edd11a 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -9,6 +9,30 @@ import torchvision.transforms.functional as F from PIL import Image +masks = torch.tensor([ + [ + [-2.2799, -2.2799, -2.2799, -2.2799, -2.2799], + [5.0914, 5.0914, 5.0914, 5.0914, 5.0914], + [-2.2799, -2.2799, -2.2799, -2.2799, -2.2799], + [-2.2799, -2.2799, -2.2799, -2.2799, -2.2799], + [-2.2799, -2.2799, -2.2799, -2.2799, -2.2799] + ], + [ + [5.0914, 5.0914, 5.0914, 5.0914, 5.0914], + [-2.2799, -2.2799, -2.2799, -2.2799, -2.2799], + [5.0914, 5.0914, 5.0914, 5.0914, 5.0914], + [5.0914, 5.0914, 5.0914, 5.0914, 5.0914], + [-1.4541, -1.4541, -1.4541, -1.4541, -1.4541] + ], + [ + [-1.4541, -1.4541, -1.4541, -1.4541, -1.4541], + [-1.4541, -1.4541, -1.4541, -1.4541, -1.4541], + [-1.4541, -1.4541, -1.4541, -1.4541, -1.4541], + [-1.4541, -1.4541, -1.4541, -1.4541, -1.4541], + [5.0914, 5.0914, 5.0914, 5.0914, 5.0914], + ] +], dtype=torch.float) + class Tester(unittest.TestCase): @@ -96,6 +120,35 @@ def test_draw_boxes(self): expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1) self.assertTrue(torch.equal(result, expected)) + def test_draw_segmentation_masks_colors(self): + img = torch.full((3, 5, 5), 255, dtype=torch.uint8) + colors = ["#FF00FF", (0, 255, 0), "red"] + result = utils.draw_segmentation_masks(img, masks, colors=colors) + + path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", + "fakedata", "draw_segm_masks_colors_util.png") + + if not os.path.exists(path): + res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy()) + res.save(path) + + expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1) + self.assertTrue(torch.equal(result, expected)) + + def test_draw_segmentation_masks_no_colors(self): + img = torch.full((3, 20, 20), 255, dtype=torch.uint8) + result = utils.draw_segmentation_masks(img, masks, colors=None) + + path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", + "fakedata", "draw_segm_masks_no_colors_util.png") + + if not os.path.exists(path): + res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy()) + res.save(path) + + expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1) + self.assertTrue(torch.equal(result, expected)) + if __name__ == '__main__': unittest.main() diff --git a/torchvision/utils.py b/torchvision/utils.py index 9ee5a0cc65c..54cf4c3e4c2 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -6,7 +6,7 @@ import numpy as np from PIL import Image, ImageDraw, ImageFont, ImageColor -__all__ = ["make_grid", "save_image", "draw_bounding_boxes"] +__all__ = ["make_grid", "save_image", "draw_bounding_boxes", "draw_segmentation_masks"] @torch.no_grad() @@ -153,7 +153,7 @@ def draw_bounding_boxes( If filled, Resulting Tensor should be saved as PNG image. Args: - image (Tensor): Tensor of shape (C x H x W) + image (Tensor): Tensor of shape (C x H x W) and dtype uint8. boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and `0 <= ymin < ymax < H`. @@ -210,3 +210,61 @@ def draw_bounding_boxes( draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=txt_font) return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1) + + +@torch.no_grad() +def draw_segmentation_masks( + image: torch.Tensor, + masks: torch.Tensor, + alpha: float = 0.2, + colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None, +) -> torch.Tensor: + + """ + Draws segmentation masks on given RGB image. + The values of the input image should be uint8 between 0 and 255. + + Args: + image (Tensor): Tensor of shape (3 x H x W) and dtype uint8. + masks (Tensor): Tensor of shape (num_masks, H, W). Each containing probability of predicted class. + alpha (float): Float number between 0 and 1 denoting factor of transpaerency of masks. + colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of masks. The colors can + be represented as `str` or `Tuple[int, int, int]`. + """ + + if not isinstance(image, torch.Tensor): + raise TypeError(f"Tensor expected, got {type(image)}") + elif image.dtype != torch.uint8: + raise ValueError(f"Tensor uint8 expected, got {image.dtype}") + elif image.dim() != 3: + raise ValueError("Pass individual images, not batches") + elif image.size()[0] != 3: + raise ValueError("Pass an RGB image. Other Image formats are not supported") + + num_masks = masks.size()[0] + masks = masks.argmax(0) + + if colors is None: + palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) + colors_t = torch.as_tensor([i for i in range(num_masks)])[:, None] * palette + color_arr = (colors_t % 255).numpy().astype("uint8") + else: + color_list = [] + for color in colors: + if isinstance(color, str): + # This will automatically raise Error if rgb cannot be parsed. + fill_color = ImageColor.getrgb(color) + color_list.append(fill_color) + elif isinstance(color, tuple): + color_list.append(color) + + color_arr = np.array(color_list).astype("uint8") + + _, h, w = image.size() + img_to_draw = Image.fromarray(masks.byte().cpu().numpy()).resize((w, h)) + img_to_draw.putpalette(color_arr) + + img_to_draw = torch.from_numpy(np.array(img_to_draw.convert('RGB'))) + img_to_draw = img_to_draw.permute((2, 0, 1)) + + return (image.float() * alpha + img_to_draw.float() * (1.0 - alpha)).to(dtype=torch.uint8)