Commit 276701f

Merge branch 'main' into canvas-size
2 parents: 893414b + bdf1622

9 files changed: +190 -24 lines changed

.git-blame-ignore-revs

Lines changed: 2 additions & 0 deletions
@@ -9,3 +9,5 @@ d367a01a18a3ae6bee13d8be3b63fd6a581ea46f
 6ca9c76adb6daf2695d603ad623a9cf1c4f4806f
 # Fix unnecessary exploded black formatting (#7709)
 a335d916db0694770e8152f41e19195de3134523
+# Renaming: `BoundingBox` -> `BoundingBoxes` (#7778)
+332bff937c6711666191880fab57fa2f23ae772e
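
For context: this file lists commits that `git blame` should skip (here, two mechanical formatting/rename commits). GitHub applies the file automatically when it sits at the repository root; locally the standard git invocations are

    git blame --ignore-revs-file .git-blame-ignore-revs <path>
    git config blame.ignoreRevsFile .git-blame-ignore-revs   # opt in once per clone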

docs/source/transforms.rst

Lines changed: 4 additions & 4 deletions
@@ -261,13 +261,13 @@ The new transform can be used standalone or mixed-and-matched with existing tran
 AugMix
 v2.AugMix
 
-Cutmix - Mixup
+CutMix - MixUp
 --------------
 
-Cutmix and Mixup are special transforms that
+CutMix and MixUp are special transforms that
 are meant to be used on batches rather than on individual images, because they
-are combining pairs of images together. These can be used after the dataloader,
-or part of a collation function. See
+are combining pairs of images together. These can be used after the dataloader
+(once the samples are batched), or part of a collation function. See
 :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage examples.
 
 .. autosummary::

gallery/plot_cutmix_mixup.py

Lines changed: 146 additions & 2 deletions
@@ -1,8 +1,152 @@
 
 """
 ===========================
-How to use Cutmix and Mixup
+How to use CutMix and MixUp
 ===========================
 
-TODO
+:class:`~torchvision.transforms.v2.Cutmix` and
+:class:`~torchvision.transforms.v2.Mixup` are popular augmentation strategies
+that can improve classification accuracy.
+
+These transforms are slightly different from the rest of the Torchvision
+transforms, because they expect
+**batches** of samples as input, not individual images. In this example we'll
+explain how to use them: after the ``DataLoader``, or as part of a collation
+function.
 """
+
+# %%
+import torch
+import torchvision
+from torchvision.datasets import FakeData
+
+# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that
+# some APIs may slightly change in the future
+torchvision.disable_beta_transforms_warning()
+
+from torchvision.transforms import v2
+
+
+NUM_CLASSES = 100
+
+# %%
+# Pre-processing pipeline
+# -----------------------
+#
+# We'll use a simple but typical image classification pipeline:
+
+preproc = v2.Compose([
+    v2.PILToTensor(),
+    v2.RandomResizedCrop(size=(224, 224), antialias=True),
+    v2.RandomHorizontalFlip(p=0.5),
+    v2.ToDtype(torch.float32, scale=True),  # to float32 in [0, 1]
+    v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # typically from ImageNet
+])
+
+dataset = FakeData(size=1000, num_classes=NUM_CLASSES, transform=preproc)
+
+img, label = dataset[0]
+print(f"{type(img) = }, {img.dtype = }, {img.shape = }, {label = }")
+
+# %%
+#
+# One important thing to note is that neither CutMix nor MixUp are part of this
+# pre-processing pipeline. We'll add them a bit later once we define the
+# DataLoader. Just as a refresher, this is what the DataLoader and training loop
+# would look like if we weren't using CutMix or MixUp:
+
+from torch.utils.data import DataLoader
+
+dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
+
+for images, labels in dataloader:
+    print(f"{images.shape = }, {labels.shape = }")
+    print(labels.dtype)
+    # <rest of the training loop here>
+    break
+# %%
+
+# %%
+# Where to use MixUp and CutMix
+# -----------------------------
+#
+# After the DataLoader
+# ^^^^^^^^^^^^^^^^^^^^
+#
+# Now let's add CutMix and MixUp. The simplest way to do this is right after the
+# DataLoader: the DataLoader has already batched the images and labels for us,
+# and this is exactly what these transforms expect as input:
+
+dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
+
+cutmix = v2.Cutmix(num_classes=NUM_CLASSES)
+mixup = v2.Mixup(num_classes=NUM_CLASSES)
+cutmix_or_mixup = v2.RandomChoice([cutmix, mixup])
+
+for images, labels in dataloader:
+    print(f"Before CutMix/MixUp: {images.shape = }, {labels.shape = }")
+    images, labels = cutmix_or_mixup(images, labels)
+    print(f"After CutMix/MixUp: {images.shape = }, {labels.shape = }")
+
+    # <rest of the training loop here>
+    break
+# %%
+#
+# Note how the labels were also transformed: we went from a batched label of
+# shape (batch_size,) to a tensor of shape (batch_size, num_classes). The
+# transformed labels can still be passed as-is to a loss function like
+# :func:`torch.nn.functional.cross_entropy`.
+#
+# As part of the collation function
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# Passing the transforms after the DataLoader is the simplest way to use CutMix
+# and MixUp, but one disadvantage is that it does not take advantage of the
+# DataLoader multi-processing. For that, we can pass those transforms as part of
+# the collation function (refer to the `PyTorch docs
+# <https://pytorch.org/docs/stable/data.html#dataloader-collate-fn>`_ to learn
+# more about collation).
+
+from torch.utils.data import default_collate
+
+
+def collate_fn(batch):
+    return cutmix_or_mixup(*default_collate(batch))
+
+
+dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2, collate_fn=collate_fn)
+
+for images, labels in dataloader:
+    print(f"{images.shape = }, {labels.shape = }")
+    # No need to call cutmix_or_mixup, it's already been called as part of the DataLoader!
+    # <rest of the training loop here>
+    break
+
+# %%
+# Non-standard input format
+# -------------------------
+#
+# So far we've used a typical sample structure where we pass ``(images,
+# labels)`` as inputs. MixUp and CutMix will magically work by default with most
+# common sample structures: tuples where the second parameter is a tensor label,
+# or dicts with a "label[s]" key. Look at the documentation of the
+# ``labels_getter`` parameter for more details.
+#
+# If your samples have a different structure, you can still use CutMix and MixUp
+# by passing a callable to the ``labels_getter`` parameter. For example:
+
+batch = {
+    "imgs": torch.rand(4, 3, 224, 224),
+    "target": {
+        "classes": torch.randint(0, NUM_CLASSES, size=(4,)),
+        "some_other_key": "this is going to be passed-through"
+    }
+}
+
+
+def labels_getter(batch):
+    return batch["target"]["classes"]
+
+
+out = v2.Cutmix(num_classes=NUM_CLASSES, labels_getter=labels_getter)(batch)
+print(f"{out['imgs'].shape = }, {out['target']['classes'].shape = }")

references/detection/coco_utils.py

Lines changed: 1 addition & 3 deletions
@@ -178,9 +178,7 @@ def get_coco_api_from_dataset(dataset):
             break
         if isinstance(dataset, torch.utils.data.Subset):
             dataset = dataset.dataset
-    if isinstance(dataset, torchvision.datasets.CocoDetection) or isinstance(
-        getattr(dataset, "_dataset", None), torchvision.datasets.CocoDetection
-    ):
+    if isinstance(dataset, torchvision.datasets.CocoDetection):
         return dataset.coco
     return convert_to_coco_api(dataset)
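
Note on this change (and the identical one in group_by_aspect_ratio.py just below): the getattr(dataset, "_dataset", ...) fallback existed for the old v2 wrapper object. After the _dataset_wrapper.py change in this commit, the wrapper is a dynamically created subclass of the wrapped dataset's class, so a plain isinstance check covers both cases. An illustrative sketch, assuming a local COCO layout (the paths are placeholders):

    import torchvision
    from torchvision.datasets import wrap_dataset_for_transforms_v2

    dataset = torchvision.datasets.CocoDetection("path/to/images", "path/to/annotations.json")
    wrapped = wrap_dataset_for_transforms_v2(dataset)

    # Both now pass, so the extra getattr(..., "_dataset") branch is unnecessary.
    assert isinstance(dataset, torchvision.datasets.CocoDetection)
    assert isinstance(wrapped, torchvision.datasets.CocoDetection)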

references/detection/group_by_aspect_ratio.py

Lines changed: 1 addition & 3 deletions
@@ -164,9 +164,7 @@ def compute_aspect_ratios(dataset, indices=None):
     if hasattr(dataset, "get_height_and_width"):
         return _compute_aspect_ratios_custom_dataset(dataset, indices)
 
-    if isinstance(dataset, torchvision.datasets.CocoDetection) or isinstance(
-        getattr(dataset, "_dataset", None), torchvision.datasets.CocoDetection
-    ):
+    if isinstance(dataset, torchvision.datasets.CocoDetection):
         return _compute_aspect_ratios_coco_dataset(dataset, indices)
 
     if isinstance(dataset, torchvision.datasets.VOCDetection):

test/datasets_utils.py

Lines changed: 4 additions & 2 deletions
@@ -571,7 +571,7 @@ def test_transforms_v2_wrapper(self, config):
         from torchvision.datasets import wrap_dataset_for_transforms_v2
 
         try:
-            with self.create_dataset(config) as (dataset, _):
+            with self.create_dataset(config) as (dataset, info):
                 for target_keys in [None, "all"]:
                     if target_keys is not None and self.DATASET_CLASS not in {
                         torchvision.datasets.CocoDetection,
@@ -584,8 +584,10 @@ def test_transforms_v2_wrapper(self, config):
                         continue
 
                     wrapped_dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
-                    wrapped_sample = wrapped_dataset[0]
+                    assert isinstance(wrapped_dataset, self.DATASET_CLASS)
+                    assert len(wrapped_dataset) == info["num_examples"]
 
+                    wrapped_sample = wrapped_dataset[0]
                     assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
         except TypeError as error:
             msg = f"No wrapper exists for dataset class {type(dataset).__name__}"

test/test_transforms_v2_refactored.py

Lines changed: 1 addition & 1 deletion
@@ -1922,7 +1922,7 @@ def test_supported_input_structure(self, T):
 
         dataset = self.DummyDataset(size=batch_size, num_classes=num_classes)
 
-        cutmix_mixup = T(alpha=0.5, num_classes=num_classes)
+        cutmix_mixup = T(num_classes=num_classes)
 
         dl = DataLoader(dataset, batch_size=batch_size)
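
For context on the ``alpha`` argument this test no longer pins: it parameterizes the Beta(alpha, alpha) distribution from which the transforms draw the mixing coefficient lam, and the constructor now defaults it to 1.0 (see the _augment.py change below). A quick illustrative check, not part of the commit:

    import torch

    for alpha in (0.5, 1.0, 2.0):
        lam = torch.distributions.Beta(alpha, alpha).sample((5,))
        # alpha < 1 pushes lam toward 0 or 1 (mild mixing), alpha = 1 is
        # uniform on [0, 1], alpha > 1 concentrates lam around 0.5
        print(f"{alpha=}: {lam.tolist()}")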

torchvision/datapoints/_dataset_wrapper.py

Lines changed: 11 additions & 3 deletions
@@ -8,7 +8,6 @@
 from collections import defaultdict
 
 import torch
-from torch.utils.data import Dataset
 
 from torchvision import datapoints, datasets
 from torchvision.transforms.v2 import functional as F
@@ -98,7 +97,16 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None):
             f"but got {target_keys}"
         )
 
-    return VisionDatasetDatapointWrapper(dataset, target_keys)
+    # Imagine we have isinstance(dataset, datasets.ImageNet). This will create a new class with the name
+    # "WrappedImageNet" at runtime that doubly inherits from VisionDatasetDatapointWrapper (see below) as well as the
+    # original ImageNet class. This allows the user to do regular isinstance(wrapped_dataset, datasets.ImageNet) checks,
+    # while we can still inject everything that we need.
+    wrapped_dataset_cls = type(f"Wrapped{type(dataset).__name__}", (VisionDatasetDatapointWrapper, type(dataset)), {})
+    # Since VisionDatasetDatapointWrapper comes before ImageNet in the MRO, calling the class hits
+    # VisionDatasetDatapointWrapper.__init__ first. Since we are never doing super().__init__(...), the constructor of
+    # ImageNet is never hit. That is by design, since we don't want to create the dataset instance again, but rather
+    # have the existing instance as attribute on the new object.
+    return wrapped_dataset_cls(dataset, target_keys)
 
 
 class WrapperFactories(dict):
@@ -117,7 +125,7 @@ def decorator(wrapper_factory):
 WRAPPER_FACTORIES = WrapperFactories()
 
 
-class VisionDatasetDatapointWrapper(Dataset):
+class VisionDatasetDatapointWrapper:
     def __init__(self, dataset, target_keys):
         dataset_cls = type(dataset)
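
The three-argument type() call above is plain Python; a self-contained toy sketch (not torchvision code) shows why the isinstance check works while the wrapped dataset's constructor is bypassed:

    class Wrapper:
        # Stands in for VisionDatasetDatapointWrapper: first in the MRO, so its
        # __init__ runs, and it never calls super().__init__().
        def __init__(self, dataset):
            self._dataset = dataset

        def __getitem__(self, idx):
            return f"wrapped({self._dataset[idx]})"


    class ToyDataset:
        def __init__(self, size):
            self.size = size

        def __getitem__(self, idx):
            return idx


    dataset = ToyDataset(size=3)
    WrappedToyDataset = type("WrappedToyDataset", (Wrapper, ToyDataset), {})
    wrapped = WrappedToyDataset(dataset)  # only Wrapper.__init__ runs

    assert isinstance(wrapped, ToyDataset)  # the point of the change
    assert wrapped[0] == "wrapped(0)"       # behavior comes from Wrapper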

torchvision/transforms/v2/_augment.py

Lines changed: 20 additions & 6 deletions
@@ -141,9 +141,9 @@ def _transform(
 
 
 class _BaseMixupCutmix(Transform):
-    def __init__(self, *, alpha: float = 1, num_classes: int, labels_getter="default") -> None:
+    def __init__(self, *, alpha: float = 1.0, num_classes: int, labels_getter="default") -> None:
         super().__init__()
-        self.alpha = alpha
+        self.alpha = float(alpha)
         self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
 
         self.num_classes = num_classes
@@ -204,13 +204,20 @@ def _mixup_label(self, label: torch.Tensor, *, lam: float) -> torch.Tensor:
 
 
 class Mixup(_BaseMixupCutmix):
-    """[BETA] Apply Mixup to the provided batch of images and labels.
+    """[BETA] Apply MixUp to the provided batch of images and labels.
 
     .. v2betastatus:: Mixup transform
 
     Paper: `mixup: Beyond Empirical Risk Minimization <https://arxiv.org/abs/1710.09412>`_.
 
-    See :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage examples.
+    .. note::
+        This transform is meant to be used on **batches** of samples, not
+        individual images. See
+        :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage
+        examples.
+        The sample pairing is deterministic and done by matching consecutive
+        samples in the batch, so the batch needs to be shuffled (this is an
+        implementation detail, not a guaranteed convention.)
 
     In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed
     into a tensor of shape ``(batch_size, num_classes)``.
@@ -246,14 +253,21 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
 
 
 class Cutmix(_BaseMixupCutmix):
-    """[BETA] Apply Cutmix to the provided batch of images and labels.
+    """[BETA] Apply CutMix to the provided batch of images and labels.
 
     .. v2betastatus:: Cutmix transform
 
     Paper: `CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features
     <https://arxiv.org/abs/1905.04899>`_.
 
-    See :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage examples.
+    .. note::
+        This transform is meant to be used on **batches** of samples, not
+        individual images. See
+        :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage
+        examples.
+        The sample pairing is deterministic and done by matching consecutive
+        samples in the batch, so the batch needs to be shuffled (this is an
+        implementation detail, not a guaranteed convention.)
 
     In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed
     into a tensor of shape ``(batch_size, num_classes)``.
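
To make the docstrings' note about pairing consecutive samples concrete, here is a minimal sketch of the MixUp idea only (illustrative, not torchvision's exact implementation; it assumes the batch-roll pairing the note describes):

    import torch

    alpha = 1.0
    batch_size, num_classes = 4, 10
    images = torch.rand(batch_size, 3, 8, 8)
    labels = torch.randint(0, num_classes, size=(batch_size,))

    # lam ~ Beta(alpha, alpha) controls how strongly each pair is blended
    lam = torch.distributions.Beta(alpha, alpha).sample().item()

    # pair each sample with the next one by rolling the batch dimension
    mixed_images = lam * images + (1 - lam) * images.roll(1, 0)

    one_hot = torch.nn.functional.one_hot(labels, num_classes).float()
    mixed_labels = lam * one_hot + (1 - lam) * one_hot.roll(1, 0)

    print(mixed_images.shape, mixed_labels.shape)  # (4, 3, 8, 8) and (4, 10)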
