Skip to content

Added utility to draw segmentation masks #3330

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 38 commits into from
Mar 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
ba9db12
add draw segm masks
oke-aditya Jan 30, 2021
24f076f
Merge branch 'master' of https://github.com/pytorch/vision into add_msks
oke-aditya Feb 1, 2021
785a6c2
Merge branch 'master' into add_msks
datumbox Feb 1, 2021
9bb4191
rewrites with new api
oke-aditya Feb 2, 2021
11adde7
Merge branch 'add_msks' of github.com:oke-aditya/vision into add_msks
oke-aditya Feb 2, 2021
d3e28b7
fix flaky colors
oke-aditya Feb 2, 2021
313f6a6
fix resize bug
oke-aditya Feb 2, 2021
80e4e4a
resize for sanity
oke-aditya Feb 2, 2021
0388c7a
cleanup
oke-aditya Feb 2, 2021
d6018e8
project the image
oke-aditya Feb 2, 2021
4af5549
Minor refactor to adopt num classes
oke-aditya Feb 4, 2021
f11fa61
add uint8 in docstring
oke-aditya Feb 4, 2021
d554c75
adds alpha and docstring
oke-aditya Feb 6, 2021
301b9de
move code a bit down
oke-aditya Feb 6, 2021
18b9cf1
Merge branch 'master' into add_msks
oke-aditya Feb 6, 2021
02226fa
Minor fix
oke-aditya Feb 8, 2021
a720f91
fix type check
oke-aditya Feb 8, 2021
155c568
Fixing resize bug.
datumbox Feb 8, 2021
b5f8e76
Merge branch 'master' into add_msks
datumbox Feb 8, 2021
8e6df85
Fix type of alpha.
datumbox Feb 8, 2021
ed5de58
Remove unnecessary RGBA conversions.
datumbox Feb 8, 2021
731ab33
update docs to supported only rgb
oke-aditya Feb 9, 2021
e582a74
minor edits
oke-aditya Feb 9, 2021
fd61118
adds tests
oke-aditya Feb 9, 2021
b0696a3
Merge branch 'master' of https://github.com/pytorch/vision into add_msks
oke-aditya Feb 9, 2021
2c7bf60
shifts masks up
oke-aditya Feb 9, 2021
cea79ef
Merge branch 'master' into add_msks
datumbox Feb 10, 2021
0a7e1fa
Merge branch 'master' into add_msks
datumbox Feb 10, 2021
58b3870
change tests and impelementation for bool
oke-aditya Feb 10, 2021
fbf4dc7
change mode to L
oke-aditya Feb 10, 2021
3dbfd67
Merge branch 'master' into add_msks
oke-aditya Feb 12, 2021
ae2cacd
Merge branch 'master' into add_msks
oke-aditya Mar 3, 2021
f5a4636
Merge branch 'master' of https://github.com/pytorch/vision into add_msks
oke-aditya Mar 8, 2021
6d8729f
Merge branch 'master' of https://github.com/pytorch/vision into add_msks
oke-aditya Mar 19, 2021
bc81e50
convert to float
oke-aditya Mar 19, 2021
32de89b
fixes docs
oke-aditya Mar 20, 2021
9b80cc8
Merge branch 'master' into add_msks
fmassa Mar 22, 2021
acc2d70
Merge branch 'master' into add_msks
fmassa Mar 22, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/source/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,6 @@ torchvision.utils

.. autofunction:: save_image

.. autofunction:: draw_bounding_boxes
.. autofunction:: draw_bounding_boxes

.. autofunction:: draw_segmentation_masks
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
53 changes: 53 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,30 @@
import torchvision.transforms.functional as F
from PIL import Image

masks = torch.tensor([
[
[-2.2799, -2.2799, -2.2799, -2.2799, -2.2799],
[5.0914, 5.0914, 5.0914, 5.0914, 5.0914],
[-2.2799, -2.2799, -2.2799, -2.2799, -2.2799],
[-2.2799, -2.2799, -2.2799, -2.2799, -2.2799],
[-2.2799, -2.2799, -2.2799, -2.2799, -2.2799]
],
[
[5.0914, 5.0914, 5.0914, 5.0914, 5.0914],
[-2.2799, -2.2799, -2.2799, -2.2799, -2.2799],
[5.0914, 5.0914, 5.0914, 5.0914, 5.0914],
[5.0914, 5.0914, 5.0914, 5.0914, 5.0914],
[-1.4541, -1.4541, -1.4541, -1.4541, -1.4541]
],
[
[-1.4541, -1.4541, -1.4541, -1.4541, -1.4541],
[-1.4541, -1.4541, -1.4541, -1.4541, -1.4541],
[-1.4541, -1.4541, -1.4541, -1.4541, -1.4541],
[-1.4541, -1.4541, -1.4541, -1.4541, -1.4541],
[5.0914, 5.0914, 5.0914, 5.0914, 5.0914],
]
], dtype=torch.float)


class Tester(unittest.TestCase):

Expand Down Expand Up @@ -96,6 +120,35 @@ def test_draw_boxes(self):
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
self.assertTrue(torch.equal(result, expected))

def test_draw_segmentation_masks_colors(self):
img = torch.full((3, 5, 5), 255, dtype=torch.uint8)
colors = ["#FF00FF", (0, 255, 0), "red"]
result = utils.draw_segmentation_masks(img, masks, colors=colors)

path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
"fakedata", "draw_segm_masks_colors_util.png")

if not os.path.exists(path):
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
res.save(path)

expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
self.assertTrue(torch.equal(result, expected))

def test_draw_segmentation_masks_no_colors(self):
img = torch.full((3, 20, 20), 255, dtype=torch.uint8)
result = utils.draw_segmentation_masks(img, masks, colors=None)

path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
"fakedata", "draw_segm_masks_no_colors_util.png")

if not os.path.exists(path):
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
res.save(path)

expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
self.assertTrue(torch.equal(result, expected))


if __name__ == '__main__':
unittest.main()
62 changes: 60 additions & 2 deletions torchvision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
from PIL import Image, ImageDraw, ImageFont, ImageColor

__all__ = ["make_grid", "save_image", "draw_bounding_boxes"]
__all__ = ["make_grid", "save_image", "draw_bounding_boxes", "draw_segmentation_masks"]


@torch.no_grad()
Expand Down Expand Up @@ -153,7 +153,7 @@ def draw_bounding_boxes(
If filled, Resulting Tensor should be saved as PNG image.

Args:
image (Tensor): Tensor of shape (C x H x W)
image (Tensor): Tensor of shape (C x H x W) and dtype uint8.
boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that
the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
`0 <= ymin < ymax < H`.
Expand Down Expand Up @@ -210,3 +210,61 @@ def draw_bounding_boxes(
draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=txt_font)

return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1)


@torch.no_grad()
def draw_segmentation_masks(
image: torch.Tensor,
masks: torch.Tensor,
alpha: float = 0.2,
colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None,
) -> torch.Tensor:

"""
Draws segmentation masks on given RGB image.
The values of the input image should be uint8 between 0 and 255.

Args:
image (Tensor): Tensor of shape (3 x H x W) and dtype uint8.
masks (Tensor): Tensor of shape (num_masks, H, W). Each containing probability of predicted class.
alpha (float): Float number between 0 and 1 denoting factor of transpaerency of masks.
colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of masks. The colors can
be represented as `str` or `Tuple[int, int, int]`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we add support for instance segmentation and panoptic segmentation, I think it would be a good idea to add an example from using the output of a semantic segmentation model and an instance segmentation model (for example from those from torchvision)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, this is a TODO. I will add GitHub gist and other minor documentation improvements for both the utilities.

Can I do this in a follow-up PR which will address all the issues as mentioned in #3364 ?

Copy link
Member

@fmassa fmassa Feb 10, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, adding documentation improvements on a follow-up PR is ok with me. What do you think about the other comment as well? Because it would be a breaking change in functionality if we support it, so better do it once (specially that the branch cut is happening very soon so if we merge it now it can get integrated in the release, in which case breaking backwards-compatibility is more annoying)

Copy link
Contributor Author

@oke-aditya oke-aditya Feb 10, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definetely, I will refactor the other comment ASAP in this PR 😄 I understand how bad it would be with BC change.

"""

if not isinstance(image, torch.Tensor):
raise TypeError(f"Tensor expected, got {type(image)}")
elif image.dtype != torch.uint8:
raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
elif image.dim() != 3:
raise ValueError("Pass individual images, not batches")
elif image.size()[0] != 3:
raise ValueError("Pass an RGB image. Other Image formats are not supported")

num_masks = masks.size()[0]
masks = masks.argmax(0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In order to unblock you can do masks.to(torch.int64).argmax(0) or cast it to float if you want.
This won't handle overlapping instances very well thought, and we will need to remove this in the future and probably replace it with a for loop so that overlapping masks are taken into account.

Plus, by using for loops and letting the mask be a floating point if the user wants, we can allow the user to have heatmaps being passed (instead of only binary maps), which would be very nice


if colors is None:
palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
colors_t = torch.as_tensor([i for i in range(num_masks)])[:, None] * palette
color_arr = (colors_t % 255).numpy().astype("uint8")
else:
color_list = []
for color in colors:
if isinstance(color, str):
# This will automatically raise Error if rgb cannot be parsed.
fill_color = ImageColor.getrgb(color)
color_list.append(fill_color)
elif isinstance(color, tuple):
color_list.append(color)

color_arr = np.array(color_list).astype("uint8")

_, h, w = image.size()
img_to_draw = Image.fromarray(masks.byte().cpu().numpy()).resize((w, h))
img_to_draw.putpalette(color_arr)

img_to_draw = torch.from_numpy(np.array(img_to_draw.convert('RGB')))
img_to_draw = img_to_draw.permute((2, 0, 1))

return (image.float() * alpha + img_to_draw.float() * (1.0 - alpha)).to(dtype=torch.uint8)