Skip to content

Commit 9bb4191

Browse files
committed
rewrites with new api
1 parent 24f076f commit 9bb4191

File tree

2 files changed

+41
-25
lines changed

2 files changed

+41
-25
lines changed

test/test_utils.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,14 +96,32 @@ def test_draw_boxes(self):
9696
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
9797
self.assertTrue(torch.equal(result, expected))
9898

99-
def test_draw_segmentation_masks(self):
99+
def test_draw_segmentation_masks_colors(self):
100100
img = torch.full((3, 20, 20), 255, dtype=torch.uint8)
101101
colors = ["green", "#FF00FF", (0, 255, 0), "red"]
102-
labels = ["a", "b", "c", "d"]
103-
boxes = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18, 19, 20]])
104-
result = utils.draw_segmentation_masks(img, boxes, labels=labels, colors=colors)
102+
# TODO
103+
masks = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
104+
result = utils.draw_segmentation_masks(img, masks, colors=colors)
105+
106+
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
107+
"fakedata", "draw_segm_masks_colors_util.png")
108+
109+
if not os.path.exists(path):
110+
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
111+
res.save(path)
112+
113+
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
114+
self.assertTrue(torch.equal(result, expected))
115+
116+
def test_draw_segmentation_masks_no_colors(self):
117+
img = torch.full((3, 20, 20), 255, dtype=torch.uint8)
118+
# TODO
119+
masks = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
120+
result = utils.draw_segmentation_masks(img, masks, colors=None)
121+
122+
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
123+
"fakedata", "draw_segm_masks_no_colors_util.png")
105124

106-
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_segm_masks_util.png")
107125
if not os.path.exists(path):
108126
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
109127
res.save(path)

torchvision/utils.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -216,10 +216,7 @@ def draw_bounding_boxes(
216216
def draw_segmentation_masks(
217217
image: torch.Tensor,
218218
masks: torch.Tensor,
219-
labels: Optional[List[str]] = None,
220219
colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None,
221-
font: Optional[str] = None,
222-
font_size: int = 10
223220
) -> torch.Tensor:
224221

225222
"""
@@ -232,10 +229,6 @@ def draw_segmentation_masks(
232229
labels (List[str]): List containing the labels of masks.
233230
colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of masks. The colors can
234231
be represented as `str` or `Tuple[int, int, int]`.
235-
font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
236-
also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`,
237-
`/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
238-
font_size (int): The requested font size in points.
239232
"""
240233

241234
if not isinstance(image, torch.Tensor):
@@ -245,21 +238,26 @@ def draw_segmentation_masks(
245238
elif image.dim() != 3:
246239
raise ValueError("Pass individual images, not batches")
247240

248-
ndarr = image.permute(1, 2, 0).numpy()
249-
img_to_draw = Image.fromarray(ndarr)
241+
img_to_draw = Image.fromarray(masks.byte().cpu().numpy()).resize(image.size)
250242

251-
img_preds = masks.to(torch.int64).tolist()
252-
253-
draw = ImageDraw.Draw(img_to_draw)
254-
txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size)
243+
if colors is None:
244+
palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
245+
colors = torch.as_tensor([i for i in range(21)])[:, None] * palette
246+
color_arr = (colors % 255).numpy().astype("uint8")
255247

256-
for i in range(len(img_preds)):
257-
for j in range(len(img_preds)):
258-
draw.point((i, j), fill=colors[img_preds[i][j]])
248+
else:
249+
color_list = []
250+
for color in colors:
251+
if isinstance(color, str):
252+
# This will automatically raise Error if rgb cannot be parsed.
253+
fill_color = ImageColor.getrgb(color) # + (100,)
254+
color_list.append(fill_color)
255+
elif isinstance(color, tuple):
256+
# fill_color = color + (100,)
257+
# Use the given colors list and create ndarray of colors.
258+
color_list.append(color)
259259

260-
if labels is not None:
261-
# Should we plot the text ?
262-
# draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=txt_font)
263-
pass
260+
color_arr = np.array(color_list).astype("uint8")
264261

262+
img_to_draw.putpalette(color_arr)
265263
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1)

0 commit comments

Comments
 (0)