
Commit b5486fb

NicolasHug authored and facebook-github-bot committed
[fbsync] Rewrite gallery example for masks to boxes. (#4484)
Summary:
* start writing example
* Update example
* Add PenFudan Files
* Update example
* Remove unused files
* Update file and adopt changes
* Create links, fix float

Reviewed By: datumbox

Differential Revision: D31268018

fbshipit-source-id: 71d7f78139ca91334ff2776efd220de31e94eeaf

Co-authored-by: Nicolas Hug <[email protected]>
1 parent 3d6b42c commit b5486fb

File tree

5 files changed: +173 -40 lines


docs/source/ops.rst

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
+.. _ops:
+
 torchvision.ops
 ===============

docs/source/utils.rst

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
+.. _utils:
+
 torchvision.utils
 =================

gallery/assets/FudanPed00054.png

309 KB

gallery/assets/FudanPed00054_mask.png

2.3 KB

gallery/plot_repurposing_annotations.py

Lines changed: 169 additions & 40 deletions
@@ -1,22 +1,37 @@
 """
-=======================
-Repurposing annotations
-=======================
-
-The following example illustrates the operations available in the torchvision.ops module for repurposing object
-localization annotations for different tasks (e.g. transforming masks used by instance and panoptic segmentation
+=====================================
+Repurposing masks into bounding boxes
+=====================================
+
+The following example illustrates the operations available in
+the :ref:`torchvision.ops <ops>` module for repurposing
+segmentation masks into object localization annotations for different tasks
+(e.g. transforming masks used by instance and panoptic segmentation
 methods into bounding boxes used by object detection methods).
 """
-import os.path
 
-import PIL.Image
-import matplotlib.patches
-import matplotlib.pyplot
-import numpy
+
+import os
+import numpy as np
 import torch
-from torchvision.ops import masks_to_boxes
+import matplotlib.pyplot as plt
+
+import torchvision.transforms.functional as F
+
+
+ASSETS_DIRECTORY = "assets"
 
-ASSETS_DIRECTORY = "../test/assets"
+plt.rcParams["savefig.bbox"] = "tight"
+
+
+def show(imgs):
+    if not isinstance(imgs, list):
+        imgs = [imgs]
+    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
+    for i, img in enumerate(imgs):
+        img = img.detach()
+        img = F.to_pil_image(img)
+        axs[0, i].imshow(np.asarray(img))
+        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
 
-matplotlib.pyplot.rcParams["savefig.bbox"] = "tight"
 
@@ -23,22 +38,17 @@
 ####################################
 # Masks
 # -----
 # In tasks like instance and panoptic segmentation, masks are commonly defined, and are defined by this package,
 # as a multi-dimensional array (e.g. a NumPy array or a PyTorch tensor) with the following shape:
 #
-# (objects, height, width)
+# (num_objects, height, width)
 #
-# Where objects is the number of annotated objects in the image. Each (height, width) object corresponds to exactly
+# Where num_objects is the number of annotated objects in the image. Each (height, width) object corresponds to exactly
 # one object. For example, if your input image has the dimensions 224 x 224 and has four annotated objects, the shape
 # of your masks annotation is the following:
 #
 # (4, 224, 224).
 #
 # A nice property of masks is that they can be easily repurposed to be used in methods to solve a variety of object
 # localization tasks.
-#
-# Masks to bounding boxes
-# ----------------------------------------
-# For example, the masks to bounding_boxes operation can be used to transform masks into bounding boxes that can be
-# used in methods like Faster RCNN and YOLO.
 
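For illustration, here is a quick sketch (not part of the commit) of the shape convention described above, using a dummy boolean tensor for four objects in a 224 x 224 image:

import torch

# One boolean (height, width) plane per annotated object.
masks = torch.zeros((4, 224, 224), dtype=torch.bool)
print(masks.shape)  # torch.Size([4, 224, 224])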

@@ -45,2 +55,15 @@
-with PIL.Image.open(os.path.join(ASSETS_DIRECTORY, "masks.tiff")) as image:
-    masks = torch.zeros((image.n_frames, image.height, image.width), dtype=torch.int)
+####################################
+# Converting Masks to Bounding Boxes
+# -----------------------------------------------
+# For example, the :func:`~torchvision.ops.masks_to_boxes` operation can be used to
+# transform masks into bounding boxes that can be
+# used as input to detection models such as FasterRCNN and RetinaNet.
+# We will take images and masks from the `Penn-Fudan Dataset <https://www.cis.upenn.edu/~jshi/ped_html/>`_.
+
+
+from torchvision.io import read_image
+
+img_path = os.path.join(ASSETS_DIRECTORY, "FudanPed00054.png")
+mask_path = os.path.join(ASSETS_DIRECTORY, "FudanPed00054_mask.png")
+img = read_image(img_path)
+mask = read_image(mask_path)
@@ -46,0 +70,32 @@
+
+
+#########################
+# Here the mask is represented as a PNG image whose pixel values encode object ids,
+# with 0 being background.
+# Notice that the spatial dimensions of image and mask match.
+
+print(mask.size())
+print(img.size())
+print(mask)
+
+############################
+
+# We get the unique colors, as these would be the object ids.
+obj_ids = torch.unique(mask)
+
+# first id is the background, so remove it.
+obj_ids = obj_ids[1:]
+
+# split the color-encoded mask into a set of boolean masks.
+# Note that this snippet would work as well if the masks were float values instead of ints.
+masks = mask == obj_ids[:, None, None]
+
+########################
+# Now the masks are a boolean tensor.
+# The first dimension in this case is 3, denoting the number of instances: there are 3 people in the image.
+# The other two dimensions are height and width, which are equal to the dimensions of the image.
+# For each instance, the boolean tensor represents whether the particular pixel
+# belongs to the segmentation mask of the image.
+
+print(masks.size())
+print(masks)
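As an aside, the ``mask == obj_ids[:, None, None]`` line above relies on broadcasting. A minimal sketch (with a made-up 3 x 3 mask, not from the commit) of what it does:

import torch

# A tiny color-encoded mask: background 0, two object ids 1 and 2.
mask = torch.tensor([[0, 1, 1],
                     [0, 2, 1],
                     [2, 2, 0]])
obj_ids = torch.unique(mask)[1:]  # tensor([1, 2]); drop the background id

# obj_ids[:, None, None] has shape (2, 1, 1), so comparing it against the
# (3, 3) mask broadcasts into a (2, 3, 3) boolean tensor: one plane per id.
masks = mask == obj_ids[:, None, None]
print(masks.shape)  # torch.Size([2, 3, 3])
print(masks[0])     # True exactly where mask == 1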
@@ -46,0 +102,33 @@
+
+####################################
+# Let us visualize an image and plot its corresponding segmentation masks.
+# We will use :func:`~torchvision.utils.draw_segmentation_masks` to draw the segmentation masks.
+
+from torchvision.utils import draw_segmentation_masks
+
+drawn_masks = []
+for mask in masks:
+    drawn_masks.append(draw_segmentation_masks(img, mask, alpha=0.8, colors="blue"))
+
+show(drawn_masks)
+
+####################################
+# To convert the boolean masks into bounding boxes, we will use
+# :func:`~torchvision.ops.masks_to_boxes` from the torchvision.ops module.
+# It returns the boxes in ``(xmin, ymin, xmax, ymax)`` format.
+
+from torchvision.ops import masks_to_boxes
+
+boxes = masks_to_boxes(masks)
+print(boxes.size())
+print(boxes)
+
+####################################
+# As the shape denotes, there are 3 boxes, in ``(xmin, ymin, xmax, ymax)`` format.
+# These can be visualized very easily with the :func:`~torchvision.utils.draw_bounding_boxes` utility
+# provided in :ref:`torchvision.utils <utils>`.
+
+from torchvision.utils import draw_bounding_boxes
+
+drawn_boxes = draw_bounding_boxes(img, boxes, colors="red")
+show(drawn_boxes)
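The boxes returned by ``masks_to_boxes`` are simply the extremes of each mask's nonzero pixels. A sketch of re-deriving the first box by hand (assuming the ``masks`` and ``boxes`` tensors computed above):

import torch

ys, xs = torch.where(masks[0])  # pixel coordinates of the first mask
manual_box = torch.stack([xs.min(), ys.min(), xs.max(), ys.max()]).to(boxes.dtype)
print(manual_box)  # should match boxes[0]: (xmin, ymin, xmax, ymax)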
@@ -46,0 +135,18 @@
+
+###################################
+# These boxes can now directly be used by detection models in torchvision.
+# Here is a demo with a Faster R-CNN model loaded from
+# :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn`.
+
+from torchvision.models.detection import fasterrcnn_resnet50_fpn
+
+model = fasterrcnn_resnet50_fpn(pretrained=True, progress=False)
+print(img.size())
+
+img = F.convert_image_dtype(img, torch.float)
+target = {}
+target["boxes"] = boxes
+target["labels"] = torch.ones((masks.size(0),), dtype=torch.int64)
+detection_outputs = model(img.unsqueeze(0), [target])
+
+
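One caveat on the demo above: a freshly constructed torchvision detection model is in training mode, so the call with ``[target]`` returns a dict of losses rather than detections. For predictions, the model is switched to eval mode first; a sketch reusing ``model`` and ``img`` from the snippet above:

# In training mode the model returns losses such as loss_classifier
# and loss_box_reg, which is what detection_outputs holds above.
print(detection_outputs.keys())

# For actual detections, switch to eval mode and pass only images.
model = model.eval()
with torch.no_grad():
    predictions = model(img.unsqueeze(0))
print(predictions[0]["boxes"].shape)  # (num_detections, 4)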
@@ -47,29 +153,52 @@
+####################################
+# Converting a Segmentation Dataset to a Detection Dataset
+# ---------------------------------------------------------
+#
+# With this utility it becomes very simple to convert a segmentation dataset to a detection dataset.
+# We can now use a segmentation dataset to train a detection model.
+# One can similarly convert a panoptic dataset to a detection dataset.
+# Here is an example where we re-purpose the dataset from the
+# `Penn-Fudan Detection Tutorial <https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html>`_.
 
-for index in range(image.n_frames):
-    image.seek(index)
+class SegmentationToDetectionDataset(torch.utils.data.Dataset):
+    def __init__(self, root, transforms):
+        self.root = root
+        self.transforms = transforms
+        # load all image files, sorting them to
+        # ensure that they are aligned
+        self.imgs = list(sorted(os.listdir(os.path.join(root, "PNGImages"))))
+        self.masks = list(sorted(os.listdir(os.path.join(root, "PedMasks"))))
 
-    frame = numpy.array(image)
+    def __getitem__(self, idx):
+        # load images and masks
+        img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
+        mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
 
-    masks[index] = torch.tensor(frame)
+        img = read_image(img_path)
+        mask = read_image(mask_path)
 
-bounding_boxes = masks_to_boxes(masks)
+        img = F.convert_image_dtype(img, dtype=torch.float)
+        mask = F.convert_image_dtype(mask, dtype=torch.float)
 
-figure = matplotlib.pyplot.figure()
+        # We get the unique colors, as these would be the object ids.
+        obj_ids = torch.unique(mask)
 
-a = figure.add_subplot(121)
-b = figure.add_subplot(122)
+        # first id is the background, so remove it.
+        obj_ids = obj_ids[1:]
 
-labeled_image = torch.sum(masks, 0)
+        # split the color-encoded mask into a set of boolean masks.
+        masks = mask == obj_ids[:, None, None]
 
-a.imshow(labeled_image)
-b.imshow(labeled_image)
+        boxes = masks_to_boxes(masks)
 
-for bounding_box in bounding_boxes:
-    x0, y0, x1, y1 = bounding_box
+        # there is only one class
+        labels = torch.ones((masks.shape[0],), dtype=torch.int64)
 
-    rectangle = matplotlib.patches.Rectangle((x0, y0), x1 - x0, y1 - y0, linewidth=1, edgecolor="r", facecolor="none")
+        target = {}
+        target["boxes"] = boxes
+        target["labels"] = labels
 
-    b.add_patch(rectangle)
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)
 
-a.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
-b.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
+        return img, target
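A sketch of how this converted dataset might be used, assuming the Penn-Fudan archive has been extracted to a local PennFudanPed/ directory containing the PNGImages/ and PedMasks/ folders the class expects:

dataset = SegmentationToDetectionDataset("PennFudanPed", transforms=None)
img, target = dataset[0]
print(img.shape)         # (3, H, W) float image tensor
print(target["boxes"])   # one (xmin, ymin, xmax, ymax) row per pedestrian
print(target["labels"])  # all ones: pedestrian is the only class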
