Added utility to draw segmentation masks #3330

Merged
merged 38 commits into from
Mar 22, 2021
Changes from 10 commits
Commits
ba9db12
add draw segm masks
oke-aditya Jan 30, 2021
24f076f
Merge branch 'master' of https://github.com/pytorch/vision into add_msks
oke-aditya Feb 1, 2021
785a6c2
Merge branch 'master' into add_msks
datumbox Feb 1, 2021
9bb4191
rewrites with new api
oke-aditya Feb 2, 2021
11adde7
Merge branch 'add_msks' of github.com:oke-aditya/vision into add_msks
oke-aditya Feb 2, 2021
d3e28b7
fix flaky colors
oke-aditya Feb 2, 2021
313f6a6
fix resize bug
oke-aditya Feb 2, 2021
80e4e4a
resize for sanity
oke-aditya Feb 2, 2021
0388c7a
cleanup
oke-aditya Feb 2, 2021
d6018e8
project the image
oke-aditya Feb 2, 2021
4af5549
Minor refactor to adopt num classes
oke-aditya Feb 4, 2021
f11fa61
add uint8 in docstring
oke-aditya Feb 4, 2021
d554c75
adds alpha and docstring
oke-aditya Feb 6, 2021
301b9de
move code a bit down
oke-aditya Feb 6, 2021
18b9cf1
Merge branch 'master' into add_msks
oke-aditya Feb 6, 2021
02226fa
Minor fix
oke-aditya Feb 8, 2021
a720f91
fix type check
oke-aditya Feb 8, 2021
155c568
Fixing resize bug.
datumbox Feb 8, 2021
b5f8e76
Merge branch 'master' into add_msks
datumbox Feb 8, 2021
8e6df85
Fix type of alpha.
datumbox Feb 8, 2021
ed5de58
Remove unnecessary RGBA conversions.
datumbox Feb 8, 2021
731ab33
update docs to supported only rgb
oke-aditya Feb 9, 2021
e582a74
minor edits
oke-aditya Feb 9, 2021
fd61118
adds tests
oke-aditya Feb 9, 2021
b0696a3
Merge branch 'master' of https://github.com/pytorch/vision into add_msks
oke-aditya Feb 9, 2021
2c7bf60
shifts masks up
oke-aditya Feb 9, 2021
cea79ef
Merge branch 'master' into add_msks
datumbox Feb 10, 2021
0a7e1fa
Merge branch 'master' into add_msks
datumbox Feb 10, 2021
58b3870
change tests and impelementation for bool
oke-aditya Feb 10, 2021
fbf4dc7
change mode to L
oke-aditya Feb 10, 2021
3dbfd67
Merge branch 'master' into add_msks
oke-aditya Feb 12, 2021
ae2cacd
Merge branch 'master' into add_msks
oke-aditya Mar 3, 2021
f5a4636
Merge branch 'master' of https://github.com/pytorch/vision into add_msks
oke-aditya Mar 8, 2021
6d8729f
Merge branch 'master' of https://github.com/pytorch/vision into add_msks
oke-aditya Mar 19, 2021
bc81e50
convert to float
oke-aditya Mar 19, 2021
32de89b
fixes docs
oke-aditya Mar 20, 2021
9b80cc8
Merge branch 'master' into add_msks
fmassa Mar 22, 2021
acc2d70
Merge branch 'master' into add_msks
fmassa Mar 22, 2021
4 changes: 3 additions & 1 deletion docs/source/utils.rst
@@ -7,4 +7,6 @@ torchvision.utils

.. autofunction:: save_image

.. autofunction:: draw_bounding_boxes
.. autofunction:: draw_bounding_boxes

.. autofunction:: draw_segmentation_masks
33 changes: 33 additions & 0 deletions test/test_utils.py
@@ -96,6 +96,39 @@ def test_draw_boxes(self):
        expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
        self.assertTrue(torch.equal(result, expected))

    def test_draw_segmentation_masks_colors(self):
        img = torch.full((3, 20, 20), 255, dtype=torch.uint8)
        colors = ["green", "#FF00FF", (0, 255, 0), "red"]
        # TODO
        masks = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
        result = utils.draw_segmentation_masks(img, masks, colors=colors)

        path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
                            "fakedata", "draw_segm_masks_colors_util.png")

        if not os.path.exists(path):
            res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
            res.save(path)

        expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
        self.assertTrue(torch.equal(result, expected))

    def test_draw_segmentation_masks_no_colors(self):
        img = torch.full((3, 20, 20), 255, dtype=torch.uint8)
        # TODO
        masks = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
        result = utils.draw_segmentation_masks(img, masks, colors=None)

        path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
                            "fakedata", "draw_segm_masks_no_colors_util.png")

        if not os.path.exists(path):
            res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
            res.save(path)

        expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
        self.assertTrue(torch.equal(result, expected))


if __name__ == '__main__':
    unittest.main()
58 changes: 57 additions & 1 deletion torchvision/utils.py
@@ -6,7 +6,7 @@
import numpy as np
from PIL import Image, ImageDraw, ImageFont, ImageColor

__all__ = ["make_grid", "save_image", "draw_bounding_boxes"]
__all__ = ["make_grid", "save_image", "draw_bounding_boxes", "draw_segmentation_masks"]


@torch.no_grad()
@@ -210,3 +210,59 @@ def draw_bounding_boxes(
draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=txt_font)

return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1)


@torch.no_grad()
def draw_segmentation_masks(
    image: torch.Tensor,
    masks: torch.Tensor,
    colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None,
) -> torch.Tensor:

    """
    Draws segmentation masks on given image.
    The values of the input image should be uint8 between 0 and 255.

    Args:
        image (Tensor): Tensor of shape (C x H x W)
        masks (Tensor): Tensor of shape (H, W). Each element contains the predicted class.
        colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of masks. The colors can
            be represented as `str` or `Tuple[int, int, int]`.
Member

If we add support for instance segmentation and panoptic segmentation, I think it would be a good idea to add an example using the output of a semantic segmentation model and an instance segmentation model (for example, those from torchvision).
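
For the semantic case, a minimal sketch of turning model output into the (H, W) class map that this function's docstring describes. It assumes the model yields per-pixel class scores of shape (num_classes, H, W); the random tensor below is only a stand-in for real model output, and all variable names are illustrative.

```python
import torch

num_classes, height, width = 21, 8, 8

# Stand-in for per-pixel class scores from a semantic segmentation model
# (assumed shape: num_classes x H x W). Real scores would come from a model.
scores = torch.randn(num_classes, height, width)

# One class index per pixel: the (H, W) layout described in the docstring.
class_map = scores.argmax(dim=0)

# Alternative layout: one boolean mask per class, shape (num_classes, H, W).
class_ids = torch.arange(num_classes)[:, None, None]
boolean_masks = class_map.unsqueeze(0) == class_ids
```

A stack of boolean masks is one representation that could also describe instance masks. Whether the utility should accept it is the question raised in this thread.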

Contributor Author

Sure, this is a TODO. I will add a GitHub gist and other minor documentation improvements for both utilities.

Can I do this in a follow-up PR that addresses all the issues mentioned in #3364?

Member

@fmassa fmassa Feb 10, 2021

Sure, adding documentation improvements in a follow-up PR is OK with me. What do you think about the other comment as well? It would be a breaking change in functionality if we support it, so it is better to do it once (especially since the branch cut is happening very soon; if we merge it now it can get integrated into the release, in which case breaking backwards compatibility is more annoying).

Contributor Author

@oke-aditya oke-aditya Feb 10, 2021

Definitely, I will refactor according to the other comment ASAP in this PR 😄 I understand how bad a BC-breaking change would be.

"""

    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Tensor expected, got {type(image)}")
    elif image.dtype != torch.uint8:
        raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")

    img_to_draw = Image.fromarray(masks.byte().cpu().numpy()).resize(image.size()[1:])

    if colors is None:
        # Default color palette assumes 21 classes.
        palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
        colors_t = torch.as_tensor([i for i in range(21)])[:, None] * palette
        color_arr = (colors_t % 255).numpy().astype("uint8")

    else:
        color_list = []
        for color in colors:
            if isinstance(color, str):
                # This will automatically raise Error if rgb cannot be parsed.
                fill_color = ImageColor.getrgb(color)  # + (100,)
                color_list.append(fill_color)
            elif isinstance(color, tuple):
                # fill_color = color + (100,)
                # Use the given colors list and create ndarray of colors.
                color_list.append(color)

        color_arr = np.array(color_list).astype("uint8")

    img_to_draw.putpalette(color_arr)
    img_to_draw = torch.from_numpy(np.array(img_to_draw))

    # Project the drawn image to the original one
    image[: 1] = img_to_draw
Contributor Author

I need a little help here, as this projects onto a black background.

My guess is we need an alpha channel, which will make the masks transparent?

Contributor

I suppose there are many ways to do it. One that I had in mind originally (which might not be optimal) is to convert img_to_draw from palette to RGBA, replace the background colour with transparent, and then combine it with image to achieve the "projection". It's worth experimenting, because there is likely a better way.
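
One possible reading of this suggestion, sketched with plain NumPy array blending rather than PIL's RGBA conversion: treat class 0 as background with zero opacity, and blend every other class color into the image at a fixed alpha. The function name `blend_masks`, its parameters, and the choice of class 0 as background are assumptions made for illustration; this is not the implementation adopted in the PR.

```python
import numpy as np


def blend_masks(image, class_map, palette, alpha=0.5, background=0):
    """Overlay per-class colors on an RGB image, leaving background pixels untouched.

    image:     uint8 array of shape (H, W, 3)
    class_map: integer array of shape (H, W), one class index per pixel
    palette:   uint8 array of shape (num_classes, 3)
    """
    # Look up the color of every pixel's class: shape (H, W, 3).
    overlay = palette[class_map].astype(np.float32)
    base = image.astype(np.float32)
    # Per-pixel opacity: 0 for background (fully transparent), alpha elsewhere.
    opacity = np.where(class_map == background, 0.0, alpha)[..., None]
    blended = opacity * overlay + (1.0 - opacity) * base
    return np.rint(blended).astype(np.uint8)


image = np.full((4, 4, 3), 255, dtype=np.uint8)  # white image
class_map = np.zeros((4, 4), dtype=np.int64)
class_map[1:3, 1:3] = 1                          # a 2x2 square of class 1
palette = np.array([[0, 0, 0], [255, 0, 0]], dtype=np.uint8)

result = blend_masks(image, class_map, palette, alpha=0.5)
```

Background pixels keep their original values, so the image no longer turns black outside the masks, while masked pixels become a mix of the class color and the underlying image.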

Contributor Author

Yep, DETR has very nice visualizations, but they use matplotlib. I'm unsure how to reproduce them though.
As you pointed out before, the Mask R-CNN utils have a nice way to apply masks too.

Contributor

I just sent it for reference, not necessarily for reproduction. :)

    return image
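
The default palette in the diff multiplies each class index by three large constants and reduces the result modulo 255. A dependency-free sketch of the same arithmetic shows which colors that produces; the helper name `default_colors` is hypothetical, while the multipliers and the class count of 21 are copied from the diff.

```python
# Multipliers copied from the diff; 21 classes is the diff's stated default.
PALETTE = (2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1)


def default_colors(num_classes=21):
    """Return one (R, G, B) tuple per class index, mirroring the diff's formula."""
    return [tuple((i * p) % 255 for p in PALETTE) for i in range(num_classes)]


colors = default_colors()
for class_index in range(4):
    print(class_index, colors[class_index])
```

Class 0 maps to (0, 0, 0), so with this palette any pixel labelled 0 is drawn black. Because the reduction is modulo 255 rather than 256, no channel ever reaches 255.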