Commit 92d75e6

pmeier authored
add gallery example for datapoints (#7321)
Co-authored-by: vfdev <[email protected]>
Co-authored-by: Nicolas Hug <[email protected]>
1 parent d0ad279 commit 92d75e6

1 file changed: +132 −0 lines changed

gallery/plot_datapoints.py

"""
==============
Datapoints FAQ
==============

The :mod:`torchvision.datapoints` namespace was introduced together with ``torchvision.transforms.v2``. This example
showcases what these datapoints are and how they behave. This is a fairly low-level topic that most users will not need
to worry about: you do not need to understand the internals of datapoints to use ``torchvision.transforms.v2``
effectively. It may, however, be useful for advanced users who want to implement their own datasets or transforms, or
who need to work directly with the datapoints.
"""

import PIL.Image

import torch
import torchvision

# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that
# some APIs may slightly change in the future
torchvision.disable_beta_transforms_warning()

from torchvision import datapoints


########################################################################################################################
# What are datapoints?
# --------------------
#
# Datapoints are zero-copy tensor subclasses:

tensor = torch.rand(3, 256, 256)
image = datapoints.Image(tensor)

assert isinstance(image, torch.Tensor)
assert image.data_ptr() == tensor.data_ptr()
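The zero-copy wrapping itself is not specific to torchvision: any ``torch.Tensor`` subclass can share the storage of an existing tensor. A minimal sketch using ``torch.Tensor.as_subclass`` (the ``MyDatapoint`` class is a hypothetical stand-in, not one of torchvision's classes):

```python
import torch

# A toy stand-in for a datapoint class; torchvision's real classes add more
# machinery, but the zero-copy wrapping can be done with ``as_subclass``.
class MyDatapoint(torch.Tensor):
    pass

tensor = torch.rand(3, 4, 4)
wrapped = tensor.as_subclass(MyDatapoint)

# The wrapper is a tensor subclass and shares the original storage: no copy was made.
assert isinstance(wrapped, MyDatapoint)
assert isinstance(wrapped, torch.Tensor)
assert wrapped.data_ptr() == tensor.data_ptr()
```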


########################################################################################################################
# Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function
# for the input data.
#
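To make "dispatch to the appropriate function for the input data" concrete, here is a toy illustration of type-based dispatch using the standard library's ``functools.singledispatch``. This is only a sketch of the idea, not torchvision's actual mechanism, and the ``Image``/``BoundingBox`` classes below are throwaway stand-ins:

```python
from functools import singledispatch

# Toy stand-ins for datapoint classes (not torchvision's real ones).
class Image(list):
    pass

class BoundingBox(list):
    pass

@singledispatch
def horizontal_flip(inpt):
    raise TypeError(f"unsupported input type {type(inpt).__name__}")

# The same transform name does different work depending on the input type.
@horizontal_flip.register
def _(inpt: Image):
    return "flipped the pixel data"

@horizontal_flip.register
def _(inpt: BoundingBox):
    return "mirrored the box coordinates"

print(horizontal_flip(Image()))        # flipped the pixel data
print(horizontal_flip(BoundingBox()))  # mirrored the box coordinates
```

Tagging the inputs with distinct types is what lets a single transform pipeline treat images, masks, and boxes differently.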
# What datapoints are supported?
# ------------------------------
#
# So far :mod:`torchvision.datapoints` supports four types of datapoints:
#
# * :class:`~torchvision.datapoints.Image`
# * :class:`~torchvision.datapoints.Video`
# * :class:`~torchvision.datapoints.BoundingBox`
# * :class:`~torchvision.datapoints.Mask`
#
# How do I construct a datapoint?
# -------------------------------
#
# Each datapoint class takes any tensor-like data that can be turned into a :class:`~torch.Tensor`:

image = datapoints.Image([[[[0, 1], [1, 0]]]])
print(image)


########################################################################################################################
# Similar to other PyTorch creation ops, the constructor also takes the ``dtype``, ``device``, and ``requires_grad``
# parameters.

float_image = datapoints.Image([[[0, 1], [1, 0]]], dtype=torch.float32, requires_grad=True)
print(float_image)


########################################################################################################################
# In addition, :class:`~torchvision.datapoints.Image` and :class:`~torchvision.datapoints.Mask` also take a
# :class:`PIL.Image.Image` directly:

image = datapoints.Image(PIL.Image.open("assets/astronaut.jpg"))
print(image.shape, image.dtype)

########################################################################################################################
# In general, the datapoints can also store additional metadata that complements the underlying tensor. For example,
# :class:`~torchvision.datapoints.BoundingBox` stores the coordinate format as well as the spatial size of the
# corresponding image alongside the actual values:

bounding_box = datapoints.BoundingBox(
    [17, 16, 344, 495], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=image.shape[-2:]
)
print(bounding_box)

########################################################################################################################
# Do I have to wrap the output of the datasets myself?
# ----------------------------------------------------
#
# Only if you are using custom datasets. For the built-in ones, you can use
# :func:`torchvision.datasets.wrap_dataset_for_transforms_v2`. Note that the function also supports subclasses of the
# built-in datasets. That means if your custom dataset subclasses a built-in one and the output type is the same, you
# don't have to wrap it manually either.
#
# How do the datapoints behave inside a computation?
# --------------------------------------------------
#
# Datapoints look and feel just like regular tensors. Everything that is supported on a plain :class:`torch.Tensor`
# also works on datapoints.
# However, since for most operations it cannot be safely inferred whether the result should retain the datapoint type,
# we currently return a plain tensor instead of a datapoint (this might change, see the note below):

assert isinstance(image, datapoints.Image)

new_image = image + 0

assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, datapoints.Image)

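This unwrapping behaviour can be reproduced with a plain ``torch.Tensor`` subclass via the ``__torch_function__`` protocol. The following is only a sketch of the mechanism, not torchvision's actual implementation:

```python
import torch

class UnwrappingTensor(torch.Tensor):
    # __torch_function__ intercepts every torch operation involving this
    # subclass; converting the result back to a plain Tensor "unwraps" it.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        result = super().__torch_function__(func, types, args, kwargs or {})
        if isinstance(result, torch.Tensor):
            return result.as_subclass(torch.Tensor)
        return result

t = torch.rand(2, 2).as_subclass(UnwrappingTensor)
out = t + 0

# The input keeps its type, but the operation's result is a plain tensor.
assert isinstance(t, UnwrappingTensor)
assert isinstance(out, torch.Tensor) and not isinstance(out, UnwrappingTensor)
```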
########################################################################################################################
# .. note::
#
#    This "unwrapping" behaviour is something we're actively seeking feedback on. If you find this surprising or if
#    you have any suggestions on how to better support your use-cases, please reach out to us via this issue:
#    https://github.com/pytorch/vision/issues/7319
#
# There are two exceptions to this rule:
#
# 1. The operations :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`, and :meth:`~torch.Tensor.requires_grad_`
#    retain the datapoint type.
# 2. Inplace operations on datapoints cannot change the type of the datapoint they are called on. However, if you use
#    the flow style, i.e. chain the calls, the returned value will be unwrapped:

image = datapoints.Image([[[0, 1], [1, 0]]])

new_image = image.add_(1).mul_(2)

# The inplace operations modified `image` itself, so it kept its datapoint type ...
assert isinstance(image, datapoints.Image)
print(image)

# ... while the value returned by the chained (flow style) call is unwrapped
assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, datapoints.Image)
assert (new_image == image).all()
