Added utility to draw segmentation masks #3330
Conversation
torchvision/utils.py
Outdated
img_to_draw = torch.from_numpy(np.array(img_to_draw))

# Project the drawn image to original one
image[:1] = img_to_draw
I need some small help here, as this currently projects the masks onto a black background.
My guess is that we need an alpha channel to make the masks transparent?
I suppose there are many ways to do it. One that I had in mind originally (which might not be optimal) is to convert the img_to_draw
from palette to RGBA, replace the background colour with transparent, and then combine it with image
to achieve the "projection". It's worth experimenting with the approach, because it's likely there is a better way.
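A rough sketch of that palette -> RGBA -> composite idea (a minimal illustration, assuming palette index 0 is the background; the helper name and the fixed overlay opacity are just placeholders, not code from this PR):

```python
import numpy as np
from PIL import Image

def overlay_palette_masks(image: Image.Image, img_to_draw: Image.Image) -> Image.Image:
    """Composite a palette ('P') mask image over an RGB image, keeping the background transparent."""
    rgba = img_to_draw.convert("RGBA")
    arr = np.array(rgba)
    background = np.array(img_to_draw) == 0   # palette index 0 assumed to be background
    arr[background, 3] = 0                    # background becomes fully transparent
    arr[~background, 3] = 128                 # masks stay semi-transparent
    return Image.alpha_composite(image.convert("RGBA"), Image.fromarray(arr))
```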
Yep, DETR has very nice visualizations, but they use matplotlib; I'm unsure how to reproduce them, though.
As you pointed out before, the Mask R-CNN utils have a nice way to apply masks too.
I just sent it for reference, not necessarily for reproduction. :)
@oke-aditya It's going in the right direction. I left a few comments to get your thoughts.
BTW, @fmassa brought to my attention a nice guide they have for DETR, with some visualization utils we might want to look into for inspiration.
Hi @datumbox, I just resolved the
I'm a bit unsure now how to proceed. One way, I think, is to create another util to apply a weighted mask, just like the Mask R-CNN matterport benchmark does. We can probably use this and apply the mask here. I'm unsure about the tests; it might be too lengthy and calculative to hand-write a 20 x 20 x 3 tensor with probabilities. I thought to just run FCN on an image instead. Let me know how to proceed! This part seems tricky: running a model does double-check the result, but it couples the tests with the models.
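For what it's worth, a small deterministic test could avoid hand-writing a 20 x 20 x 3 probability tensor by using tiny inputs instead. A hedged sketch (the exact signature of draw_segmentation_masks used here is an assumption, since the API was still in flux at this point in the review):

```python
import torch
from torchvision.utils import draw_segmentation_masks

def test_draw_segmentation_masks_shape_and_dtype():
    image = torch.zeros(3, 4, 4, dtype=torch.uint8)
    masks = torch.zeros(2, 4, 4)
    masks[1, :, 2:] = 1.0          # right half predicted as class 1
    out = draw_segmentation_masks(image, masks, alpha=0.5)
    assert out.dtype == torch.uint8
    assert out.shape[-2:] == image.shape[-2:]
```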
@oke-aditya Great changes, I think we are almost there! As I commented above, there are multiple ways to do this. Here is a hacky, quick-and-dirty approach to get an idea; I'm sure you can do it in a much better way:

@torch.no_grad()
def draw_segmentation_masks(
    image: torch.Tensor,
    masks: torch.Tensor,
    colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None,
) -> torch.Tensor:
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Tensor expected, got {type(image)}")
    elif image.dtype != torch.uint8:
        raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")

    # Collapse the [num_classes, H, W] scores into a single class-index map.
    classifications = masks.argmax(0).byte()
    img_to_draw = Image.fromarray(classifications.cpu().numpy())

    if colors is None:
        # Deterministic pseudo-random RGBA colour per class; class 0 (background) keeps alpha 0.
        num_classes = masks.size(0)
        palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1, 0])
        colors_t = torch.as_tensor([i for i in range(num_classes)])[:, None] * palette
        color_arr = (colors_t % 255).numpy().astype("uint8")
        color_arr[1:, 3] = 255
    else:
        color_list = []
        for color in colors:
            if isinstance(color, str):
                fill_color = ImageColor.getrgb(color)
                color_list.append(fill_color)
            elif isinstance(color, tuple):
                color_list.append(color)
        color_arr = np.array(color_list).astype("uint8")

    img_to_draw.putpalette(color_arr)
    img_to_draw = torch.from_numpy(np.array(img_to_draw.convert('RGBA')))
    img_to_draw = img_to_draw.permute((2, 0, 1))

    # Alpha-blend the RGBA overlay onto the image (padded with a fully opaque alpha channel).
    alpha = 0.6
    return (
        torch.cat([image, torch.full(image.shape[1:], 255).unsqueeze(0)]).float() * alpha
        + img_to_draw.float() * (1.0 - alpha)
    ).to(dtype=torch.uint8)

Ping me when you have a demo you are comfortable with to discuss the last details of the API. :)
Sorry for the delay @datumbox. Here are a few outputs for different values of alpha (output images not reproduced here): Alpha = 0.2, Alpha = 0.3, Alpha = 0.6, Alpha = 0.7.
Also I made
Another thought I had was to make the util
Hi,
The PR looks very nice, thanks a lot!
I have one suggestion which I think would make the function more generic and wouldn't involve too many changes. Let me know what you think.
torchvision/utils.py
Outdated
num_classes = masks.size()[0]
masks = masks.argmax(0)
Instead of being specific to semantic segmentation and not supporting instance segmentation or panoptic segmentation, I think we could make this slightly more generic while supporting all the use-cases I mentioned. The idea would be to accept a mask as a [num_masks, H, W] boolean Tensor.
This way, the user can get the semantic segmentation masks to pass to this function as follows:
out.argmax(0) == torch.arange(out.shape[0])[:, None, None]
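To make the suggestion concrete, here is a small self-contained illustration of that expression (`out` is a stand-in for a model's per-class score map; the shapes are made up):

```python
import torch

num_classes, H, W = 21, 4, 4
out = torch.randn(num_classes, H, W)      # e.g. per-class scores from a segmentation model

# One boolean mask per class, shaped [num_masks, H, W].
masks = out.argmax(0) == torch.arange(out.shape[0])[:, None, None]

assert masks.dtype == torch.bool and masks.shape == (num_classes, H, W)
assert (masks.sum(0) == 1).all()          # every pixel belongs to exactly one class mask
```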
masks (Tensor): Tensor of shape (num_classes, H, W). Each containing probability of predicted class.
alpha (float): Float number between 0 and 1 denoting factor of transparency of masks.
colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of masks. The colors can
be represented as `str` or `Tuple[int, int, int]`.
If we add support for instance segmentation and panoptic segmentation, I think it would be a good idea to add an example of using the output of a semantic segmentation model and an instance segmentation model (for example, those from torchvision).
Sure, this is a TODO. I will add a GitHub gist and other minor documentation improvements for both utilities.
Can I do this in a follow-up PR which will address all the issues mentioned in #3364?
Sure, adding documentation improvements in a follow-up PR is OK with me. What do you think about the other comment as well? It would be a breaking change in functionality if we support it, so better to do it once (especially since the branch cut is happening very soon; if we merge it now it can get integrated into the release, in which case breaking backwards compatibility is more annoying).
Definitely, I will refactor per the other comment ASAP in this PR 😄 I understand how bad a BC change would be.
Slight problem while implementing. I am pushing my latest changes; can someone please have a look? A simple code to reproduce the bug:
Maybe there is some workaround? I guess this was the only change needed to support boolean masks. The error occurs when we try to compute the argmax.
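The repro snippet and error message were not captured above; given the workaround suggested further down (`masks.to(torch.int64).argmax(0)`), the failure is presumably argmax not being implemented for Bool tensors. A hedged illustration:

```python
import torch

bool_masks = torch.zeros(21, 4, 4, dtype=torch.bool)

try:
    bool_masks.argmax(0)                      # raises on builds where argmax has no Bool kernel
except RuntimeError as err:
    print(err)

idx = bool_masks.to(torch.int64).argmax(0)    # the suggested cast makes argmax work again
```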
@oke-aditya maybe I'm missing something, but if we pass a bool tensor we don't need to compute the argmax anymore, because the independent masks have already been computed?
EDIT: oh I see, you don't perform a for loop over each one of the masks as of now. Computing the argmax could be done after casting the mask to float, for example, but note that in instance segmentation each pixel can be covered by multiple masks, so the argmax wouldn't be enough to handle those. But it can be a first approximation.
That's possible. But may I know why we decided to pass a Bool tensor and not a float mask tensor?
Edit: My idea was to make the initial implementation compatible with single-channel masks. That would make it compatible with Mask R-CNN. I thought that, in the same function, we could handle the single-channel case differently.
I think, with the release cut coming quite soon, we could wait and add this in the next release? (I mean after 0.9.0)
Sorry for the delay in reviewing.
I've made a comment to unblock you, but I think in the future we might need to change the implementation to let it handle overlapping masks (which is the whole purpose of the function accepting a tensor with a per-object map).
The input mask doesn't need to be a boolean by the way, but it can be a floating point representing probabilities for the given instance.
But to unblock for now let's just make this small change I proposed, and we can improve this in a follow-up PR
raise ValueError("Pass an RGB image. Other Image formats are not supported")

num_masks = masks.size()[0]
masks = masks.argmax(0)
In order to unblock you can do masks.to(torch.int64).argmax(0)
or cast it to float if you want.
This won't handle overlapping instances very well though, and we will need to remove this in the future and probably replace it with a for loop so that overlapping masks are taken into account.
Plus, by using for loops and letting the mask be floating point if the user wants, we can allow the user to pass heatmaps (instead of only binary maps), which would be very nice.
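A rough sketch of the per-mask loop being suggested (illustrative only, not the code merged in this PR): each mask's colour is blended into the image in turn, so overlapping instances and soft probability masks are both handled.

```python
import torch
from typing import List, Tuple

def blend_masks(image: torch.Tensor, masks: torch.Tensor,
                colors: List[Tuple[int, int, int]], alpha: float = 0.6) -> torch.Tensor:
    # image: uint8 [3, H, W]; masks: float [num_masks, H, W] with values in [0, 1]
    # (boolean masks can simply be cast with masks.float()); one color per mask.
    out = image.float()
    for mask, color in zip(masks, colors):
        color_t = torch.tensor(color, dtype=torch.float).view(3, 1, 1)
        weight = mask.float() * (1.0 - alpha)           # per-pixel blending weight
        out = out * (1.0 - weight) + color_t * weight   # later masks blend on top of earlier ones
    return out.to(torch.uint8)
```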
Extremely sorry for the delay (my health let me down 😢). I think the boolean tensor leads to some limitations, and we would again re-cast it to int/float, so I refactored to use floating-point masks, where each point represents the probability of a class. Currently, this code accepts a floating-point tensor of shape (num_classes, H, W). Let me know what we need to do in further PRs / this PR.
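For reference, an end-to-end sketch of how the utility at this stage of the PR could be exercised (the model choice, image path, and default colours are assumptions, and input normalization is omitted for brevity):

```python
import torch
from torchvision.io import read_image
from torchvision.models.segmentation import fcn_resnet50
from torchvision.transforms.functional import convert_image_dtype
from torchvision.utils import draw_segmentation_masks

img = read_image("dog.png")                         # uint8 tensor of shape [3, H, W]
model = fcn_resnet50(pretrained=True).eval()

with torch.no_grad():
    batch = convert_image_dtype(img, torch.float).unsqueeze(0)
    out = model(batch)["out"][0]                    # [num_classes, H, W] per-class scores

result = draw_segmentation_masks(img, out.softmax(0), alpha=0.6)
```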
Thanks!
Summary:
* add draw segm masks
* rewrites with new api
* fix flaky colors
* fix resize bug
* resize for sanity
* cleanup
* project the image
* Minor refactor to adopt num classes
* add uint8 in docstring
* adds alpha and docstring
* move code a bit down
* Minor fix
* fix type check
* Fixing resize bug.
* Fix type of alpha.
* Remove unnecessary RGBA conversions.
* update docs to supported only rgb
* minor edits
* adds tests
* shifts masks up
* change tests and implementation for bool
* change mode to L
* convert to float
* fixes docs

Reviewed By: fmassa

Differential Revision: D27433933

fbshipit-source-id: 26e72b4f8471218631b26cc555422890b0f6b81d

Co-authored-by: Vasilis Vryniotis <[email protected]>
Co-authored-by: Francisco Massa <[email protected]>
Closes #3272.
Initial Implementation