-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Added utility to draw segmentation masks #3330
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ba9db12
24f076f
785a6c2
9bb4191
11adde7
d3e28b7
313f6a6
80e4e4a
0388c7a
d6018e8
4af5549
f11fa61
d554c75
301b9de
18b9cf1
02226fa
a720f91
155c568
b5f8e76
8e6df85
ed5de58
731ab33
e582a74
fd61118
b0696a3
2c7bf60
cea79ef
0a7e1fa
58b3870
fbf4dc7
3dbfd67
ae2cacd
f5a4636
6d8729f
bc81e50
32de89b
9b80cc8
acc2d70
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,7 @@ | |
import numpy as np | ||
from PIL import Image, ImageDraw, ImageFont, ImageColor | ||
|
||
__all__ = ["make_grid", "save_image", "draw_bounding_boxes"] | ||
__all__ = ["make_grid", "save_image", "draw_bounding_boxes", "draw_segmentation_masks"] | ||
|
||
|
||
@torch.no_grad() | ||
|
@@ -153,7 +153,7 @@ def draw_bounding_boxes( | |
If filled, Resulting Tensor should be saved as PNG image. | ||
|
||
Args: | ||
image (Tensor): Tensor of shape (C x H x W) | ||
image (Tensor): Tensor of shape (C x H x W) and dtype uint8. | ||
boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that | ||
the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and | ||
`0 <= ymin < ymax < H`. | ||
|
@@ -210,3 +210,61 @@ def draw_bounding_boxes( | |
draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=txt_font) | ||
|
||
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1) | ||
|
||
|
||
@torch.no_grad() | ||
def draw_segmentation_masks( | ||
image: torch.Tensor, | ||
masks: torch.Tensor, | ||
alpha: float = 0.2, | ||
colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None, | ||
) -> torch.Tensor: | ||
|
||
""" | ||
Draws segmentation masks on given RGB image. | ||
The values of the input image should be uint8 between 0 and 255. | ||
|
||
Args: | ||
image (Tensor): Tensor of shape (3 x H x W) and dtype uint8. | ||
masks (Tensor): Tensor of shape (num_masks, H, W). Each containing probability of predicted class. | ||
alpha (float): Float number between 0 and 1 denoting factor of transpaerency of masks. | ||
colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of masks. The colors can | ||
be represented as `str` or `Tuple[int, int, int]`. | ||
""" | ||
|
||
if not isinstance(image, torch.Tensor): | ||
raise TypeError(f"Tensor expected, got {type(image)}") | ||
elif image.dtype != torch.uint8: | ||
raise ValueError(f"Tensor uint8 expected, got {image.dtype}") | ||
elif image.dim() != 3: | ||
raise ValueError("Pass individual images, not batches") | ||
datumbox marked this conversation as resolved.
Show resolved
Hide resolved
|
||
elif image.size()[0] != 3: | ||
raise ValueError("Pass an RGB image. Other Image formats are not supported") | ||
|
||
num_masks = masks.size()[0] | ||
masks = masks.argmax(0) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In order to unblock you can do Plus, by using for loops and letting the mask be a floating point if the user wants, we can allow the user to have heatmaps being passed (instead of only binary maps), which would be very nice |
||
|
||
if colors is None: | ||
palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) | ||
colors_t = torch.as_tensor([i for i in range(num_masks)])[:, None] * palette | ||
color_arr = (colors_t % 255).numpy().astype("uint8") | ||
else: | ||
color_list = [] | ||
for color in colors: | ||
if isinstance(color, str): | ||
# This will automatically raise Error if rgb cannot be parsed. | ||
fill_color = ImageColor.getrgb(color) | ||
color_list.append(fill_color) | ||
elif isinstance(color, tuple): | ||
color_list.append(color) | ||
|
||
color_arr = np.array(color_list).astype("uint8") | ||
|
||
_, h, w = image.size() | ||
img_to_draw = Image.fromarray(masks.byte().cpu().numpy()).resize((w, h)) | ||
img_to_draw.putpalette(color_arr) | ||
|
||
img_to_draw = torch.from_numpy(np.array(img_to_draw.convert('RGB'))) | ||
datumbox marked this conversation as resolved.
Show resolved
Hide resolved
|
||
img_to_draw = img_to_draw.permute((2, 0, 1)) | ||
|
||
return (image.float() * alpha + img_to_draw.float() * (1.0 - alpha)).to(dtype=torch.uint8) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we add support for instance segmentation and panoptic segmentation, I think it would be a good idea to add an example from using the output of a semantic segmentation model and an instance segmentation model (for example from those from torchvision)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, this is a TODO. I will add GitHub gist and other minor documentation improvements for both the utilities.
Can I do this in a follow-up PR which will address all the issues as mentioned in #3364 ?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, adding documentation improvements on a follow-up PR is ok with me. What do you think about the other comment as well? Because it would be a breaking change in functionality if we support it, so better do it once (specially that the branch cut is happening very soon so if we merge it now it can get integrated in the release, in which case breaking backwards-compatibility is more annoying)
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Definetely, I will refactor the other comment ASAP in this PR 😄 I understand how bad it would be with BC change.