-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Added utility to draw segmentation masks #3330
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 10 commits
ba9db12
24f076f
785a6c2
9bb4191
11adde7
d3e28b7
313f6a6
80e4e4a
0388c7a
d6018e8
4af5549
f11fa61
d554c75
301b9de
18b9cf1
02226fa
a720f91
155c568
b5f8e76
8e6df85
ed5de58
731ab33
e582a74
fd61118
b0696a3
2c7bf60
cea79ef
0a7e1fa
58b3870
fbf4dc7
3dbfd67
ae2cacd
f5a4636
6d8729f
bc81e50
32de89b
9b80cc8
acc2d70
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,7 @@ | |
import numpy as np | ||
from PIL import Image, ImageDraw, ImageFont, ImageColor | ||
|
||
__all__ = ["make_grid", "save_image", "draw_bounding_boxes"] | ||
__all__ = ["make_grid", "save_image", "draw_bounding_boxes", "draw_segmentation_masks"] | ||
|
||
|
||
@torch.no_grad() | ||
|
@@ -210,3 +210,59 @@ def draw_bounding_boxes( | |
draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=txt_font) | ||
|
||
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1) | ||
|
||
|
||
@torch.no_grad() | ||
def draw_segmentation_masks( | ||
image: torch.Tensor, | ||
masks: torch.Tensor, | ||
colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None, | ||
) -> torch.Tensor: | ||
|
||
""" | ||
Draws segmentation masks on given image. | ||
The values of the input image should be uint8 between 0 and 255. | ||
|
||
Args: | ||
image (Tensor): Tensor of shape (C x H x W) | ||
masks (Tensor): Tensor of shape (H, W). Each containing predicted class. | ||
labels (List[str]): List containing the labels of masks. | ||
colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of masks. The colors can | ||
be represented as `str` or `Tuple[int, int, int]`. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we add support for instance segmentation and panoptic segmentation, I think it would be a good idea to add an example from using the output of a semantic segmentation model and an instance segmentation model (for example from those from torchvision) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, this is a TODO. I will add GitHub gist and other minor documentation improvements for both the utilities. Can I do this in a follow-up PR which will address all the issues as mentioned in #3364 ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, adding documentation improvements on a follow-up PR is ok with me. What do you think about the other comment as well? Because it would be a breaking change in functionality if we support it, so better do it once (specially that the branch cut is happening very soon so if we merge it now it can get integrated in the release, in which case breaking backwards-compatibility is more annoying) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Definetely, I will refactor the other comment ASAP in this PR 😄 I understand how bad it would be with BC change. |
||
""" | ||
|
||
if not isinstance(image, torch.Tensor): | ||
raise TypeError(f"Tensor expected, got {type(image)}") | ||
elif image.dtype != torch.uint8: | ||
raise ValueError(f"Tensor uint8 expected, got {image.dtype}") | ||
elif image.dim() != 3: | ||
raise ValueError("Pass individual images, not batches") | ||
datumbox marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
img_to_draw = Image.fromarray(masks.byte().cpu().numpy()).resize(image.size()[1:]) | ||
oke-aditya marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if colors is None: | ||
# Default color palette assumes 21 classes. | ||
oke-aditya marked this conversation as resolved.
Show resolved
Hide resolved
|
||
palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) | ||
colors_t = torch.as_tensor([i for i in range(21)])[:, None] * palette | ||
color_arr = (colors_t % 255).numpy().astype("uint8") | ||
|
||
else: | ||
color_list = [] | ||
for color in colors: | ||
if isinstance(color, str): | ||
# This will automatically raise Error if rgb cannot be parsed. | ||
fill_color = ImageColor.getrgb(color) # + (100,) | ||
color_list.append(fill_color) | ||
elif isinstance(color, tuple): | ||
# fill_color = color + (100,) | ||
# Use the given colors list and create ndarray of colors. | ||
color_list.append(color) | ||
|
||
color_arr = np.array(color_list).astype("uint8") | ||
|
||
img_to_draw.putpalette(color_arr) | ||
img_to_draw = torch.from_numpy(np.array(img_to_draw)) | ||
|
||
# Project the drawn image to orignal one | ||
image[: 1] = img_to_draw | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I need small help here as this projects to a black background. My guess is we need an alpha channel which will make masks transparent ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I supposed there are many ways to do it. One that I had in mind originally (which might not be the optimal) is to convert the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep, DETR has very nice visualization, but they use matplotlib. Unsure how to reproduce them though. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I just sent it for reference not necessarily for reproduction. :) |
||
return image |
Uh oh!
There was an error while loading. Please reload this page.