Skip to content

Commit b9a704e

Browse files
parmeetdatumboxfmassa
authored andcommitted
[fbsync] Added utility to draw segmentation masks (#3330)
Summary: * add draw segm masks * rewrites with new api * fix flaky colors * fix resize bug * resize for sanity * cleanup * project the image * Minor refactor to adopt num classes * add uint8 in docstring * adds alpha and docstring * move code a bit down * Minor fix * fix type check * Fixing resize bug. * Fix type of alpha. * Remove unnecessary RGBA conversions. * update docs to supported only rgb * minor edits * adds tests * shifts masks up * change tests and implementation for bool * change mode to L * convert to float * fixes docs Reviewed By: fmassa Differential Revision: D27433933 fbshipit-source-id: 26e72b4f8471218631b26cc555422890b0f6b81d Co-authored-by: Vasilis Vryniotis <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]> Co-authored-by: Francisco Massa <[email protected]>
1 parent cbd68d7 commit b9a704e

File tree

5 files changed

+116
-3
lines changed

5 files changed

+116
-3
lines changed

docs/source/utils.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,6 @@ torchvision.utils
77

88
.. autofunction:: save_image
99

10-
.. autofunction:: draw_bounding_boxes
10+
.. autofunction:: draw_bounding_boxes
11+
12+
.. autofunction:: draw_segmentation_masks
Loading
Loading

test/test_utils.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,30 @@
99
import torchvision.transforms.functional as F
1010
from PIL import Image
1111

12+
# Shared fixture for the segmentation-mask tests: three 5x5 maps, one per mask.
# Every row is constant, so the per-pixel argmax over the first dimension depends
# only on the row index (rows 0..4 select masks 1, 0, 1, 1, 2 respectively).
masks = torch.tensor([
    [[-2.2799] * 5, [5.0914] * 5, [-2.2799] * 5, [-2.2799] * 5, [-2.2799] * 5],
    [[5.0914] * 5, [-2.2799] * 5, [5.0914] * 5, [5.0914] * 5, [-1.4541] * 5],
    [[-1.4541] * 5, [-1.4541] * 5, [-1.4541] * 5, [-1.4541] * 5, [5.0914] * 5],
], dtype=torch.float)
35+
1236

1337
class Tester(unittest.TestCase):
1438

@@ -96,6 +120,35 @@ def test_draw_boxes(self):
96120
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
97121
self.assertTrue(torch.equal(result, expected))
98122

123+
def test_draw_segmentation_masks_colors(self):
    """Draw the module-level ``masks`` with explicit colors and compare to the stored PNG."""
    canvas = torch.full((3, 5, 5), 255, dtype=torch.uint8)
    # One color per mask, mixing the accepted notations: hex string, RGB tuple, color name.
    palette = ["#FF00FF", (0, 255, 0), "red"]
    drawn = utils.draw_segmentation_masks(canvas, masks, colors=palette)

    test_dir = os.path.dirname(os.path.abspath(__file__))
    reference = os.path.join(test_dir, "assets", "fakedata", "draw_segm_masks_colors_util.png")

    # When the reference image is missing, it is created from the current output.
    if not os.path.exists(reference):
        Image.fromarray(drawn.permute(1, 2, 0).contiguous().numpy()).save(reference)

    # The PNG is loaded as H x W x C; move channels first to match the function's output.
    reference_chw = torch.as_tensor(np.array(Image.open(reference))).permute(2, 0, 1)
    self.assertTrue(torch.equal(drawn, reference_chw))
137+
138+
def test_draw_segmentation_masks_no_colors(self):
    """Draw the module-level ``masks`` with the default palette and compare to the stored PNG."""
    # The canvas is 20x20, larger than the 5x5 masks fixture.
    canvas = torch.full((3, 20, 20), 255, dtype=torch.uint8)
    drawn = utils.draw_segmentation_masks(canvas, masks, colors=None)

    test_dir = os.path.dirname(os.path.abspath(__file__))
    reference = os.path.join(test_dir, "assets", "fakedata", "draw_segm_masks_no_colors_util.png")

    # When the reference image is missing, it is created from the current output.
    if not os.path.exists(reference):
        Image.fromarray(drawn.permute(1, 2, 0).contiguous().numpy()).save(reference)

    # The PNG is loaded as H x W x C; move channels first to match the function's output.
    reference_chw = torch.as_tensor(np.array(Image.open(reference))).permute(2, 0, 1)
    self.assertTrue(torch.equal(drawn, reference_chw))
151+
99152

100153
if __name__ == '__main__':
101154
unittest.main()

torchvision/utils.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy as np
77
from PIL import Image, ImageDraw, ImageFont, ImageColor
88

9-
__all__ = ["make_grid", "save_image", "draw_bounding_boxes"]
9+
__all__ = ["make_grid", "save_image", "draw_bounding_boxes", "draw_segmentation_masks"]
1010

1111

1212
@torch.no_grad()
@@ -153,7 +153,7 @@ def draw_bounding_boxes(
153153
If filled, Resulting Tensor should be saved as PNG image.
154154
155155
Args:
156-
image (Tensor): Tensor of shape (C x H x W)
156+
image (Tensor): Tensor of shape (C x H x W) and dtype uint8.
157157
boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that
158158
the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
159159
`0 <= ymin < ymax < H`.
@@ -210,3 +210,61 @@ def draw_bounding_boxes(
210210
draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=txt_font)
211211

212212
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1)
213+
214+
215+
@torch.no_grad()
def draw_segmentation_masks(
    image: torch.Tensor,
    masks: torch.Tensor,
    alpha: float = 0.2,
    colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None,
) -> torch.Tensor:

    """
    Draws segmentation masks on given RGB image.
    The values of the input image should be uint8 between 0 and 255.

    Every pixel is assigned to the mask holding the largest value at that position (argmax over the
    first dimension of ``masks``), painted with that mask's color and blended with the input image as
    ``image * alpha + colored_masks * (1 - alpha)``.

    Args:
        image (Tensor): Tensor of shape (3 x H x W) and dtype uint8.
        masks (Tensor): Tensor of shape (num_masks, H, W). Each containing probability of predicted class.
            If the spatial size differs from the image, the per-pixel mask indices are resized to (H, W)
            with nearest-neighbour sampling.
        alpha (float): Float number between 0 and 1 denoting factor of transparency of masks.
        colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of masks. The colors can
            be represented as `str` or `Tuple[int, int, int]`. Entry ``i`` is used for mask ``i``, so at
            least ``num_masks`` entries are required. If None, a fixed palette is generated.

    Returns:
        Tensor: uint8 tensor of shape (3 x H x W) holding the blended image, on the same device as ``image``.

    Raises:
        TypeError: If ``image`` or ``masks`` is not a Tensor, or a color is neither a string nor a
            tuple/list.
        ValueError: If ``image`` is not a single uint8 RGB image, ``masks`` is not 3-dimensional, there
            are more than 256 masks, a color does not have 3 components, or fewer colors than masks
            are given.
    """

    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Tensor expected, got {type(image)}")
    elif image.dtype != torch.uint8:
        raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    elif image.size()[0] != 3:
        raise ValueError("Pass an RGB image. Other Image formats are not supported")

    if not isinstance(masks, torch.Tensor):
        raise TypeError(f"Tensor expected for masks, got {type(masks)}")
    elif masks.dim() != 3:
        raise ValueError(f"masks of shape (num_masks, H, W) expected, got {masks.dim()} dimensions")

    num_masks = masks.size()[0]
    # Mask indices are stored as uint8 and looked up in a PIL palette, both limited to 256 entries.
    if num_masks > 256:
        raise ValueError(f"At most 256 masks are supported, got {num_masks}")

    # Collapse (num_masks, H, W) into a single (H, W) map of winning mask indices.
    masks = masks.argmax(0)

    if colors is None:
        palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
        colors_t = torch.arange(num_masks)[:, None] * palette
        color_arr = (colors_t % 255).numpy().astype("uint8")
    else:
        color_list = []
        for color in colors:
            if isinstance(color, str):
                # This will automatically raise Error if rgb cannot be parsed.
                color_list.append(ImageColor.getrgb(color))
            elif isinstance(color, (tuple, list)):
                if len(color) != 3:
                    raise ValueError(f"RGB color with 3 components expected, got {color!r}")
                color_list.append(tuple(color))
            else:
                # Skipping the entry would shift every following color onto the wrong mask.
                raise TypeError(f"Color must be a str or a tuple of 3 ints, got {type(color)}")

        if len(color_list) < num_masks:
            raise ValueError(f"Expected at least {num_masks} colors (one per mask), got {len(color_list)}")

        # Extra colors are never indexed; dropping them keeps the palette within PIL's size limit.
        color_arr = np.array(color_list[:num_masks]).astype("uint8")

    _, h, w = image.size()
    # The index image is mode "L", for which PIL would otherwise interpolate between mask indices
    # and invent labels along boundaries; nearest-neighbour keeps every pixel a valid index.
    img_to_draw = Image.fromarray(masks.byte().cpu().numpy()).resize((w, h), resample=Image.NEAREST)
    img_to_draw.putpalette(color_arr.flatten().tolist())

    img_to_draw = torch.from_numpy(np.array(img_to_draw.convert('RGB')))
    # PIL yields H x W x C on the CPU; match the layout and device of the input image.
    img_to_draw = img_to_draw.permute((2, 0, 1)).to(image.device)

    return (image.float() * alpha + img_to_draw.float() * (1.0 - alpha)).to(dtype=torch.uint8)

0 commit comments

Comments
 (0)