Skip to content

Enable one-hot-encoded labels in MixUp and CutMix #8427

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 19 additions & 16 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2169,26 +2169,30 @@ def test_image_correctness(self, brightness_factor):

class TestCutMixMixUp:
class DummyDataset:
def __init__(self, size, num_classes):
def __init__(self, size, num_classes, one_hot_labels):
self.size = size
self.num_classes = num_classes
self.one_hot_labels = one_hot_labels
assert size < num_classes

def __getitem__(self, idx):
img = torch.rand(3, 100, 100)
label = idx # This ensures all labels in a batch are unique and makes testing easier
if self.one_hot_labels:
label = torch.nn.functional.one_hot(torch.tensor(label), num_classes=self.num_classes)
return img, label

def __len__(self):
return self.size

@pytest.mark.parametrize("T", [transforms.CutMix, transforms.MixUp])
def test_supported_input_structure(self, T):
@pytest.mark.parametrize("one_hot_labels", (True, False))
def test_supported_input_structure(self, T, one_hot_labels):

batch_size = 32
num_classes = 100

dataset = self.DummyDataset(size=batch_size, num_classes=num_classes)
dataset = self.DummyDataset(size=batch_size, num_classes=num_classes, one_hot_labels=one_hot_labels)

cutmix_mixup = T(num_classes=num_classes)

Expand All @@ -2198,7 +2202,7 @@ def test_supported_input_structure(self, T):
img, target = next(iter(dl))
input_img_size = img.shape[-3:]
assert isinstance(img, torch.Tensor) and isinstance(target, torch.Tensor)
assert target.shape == (batch_size,)
assert target.shape == (batch_size, num_classes) if one_hot_labels else (batch_size,)

def check_output(img, target):
assert img.shape == (batch_size, *input_img_size)
Expand All @@ -2209,7 +2213,7 @@ def check_output(img, target):

# After Dataloader, as unpacked input
img, target = next(iter(dl))
assert target.shape == (batch_size,)
assert target.shape == (batch_size, num_classes) if one_hot_labels else (batch_size,)
img, target = cutmix_mixup(img, target)
check_output(img, target)

Expand Down Expand Up @@ -2264,30 +2268,29 @@ def test_error(self, T):
with pytest.raises(ValueError, match="Could not infer where the labels are"):
cutmix_mixup({"img": imgs, "Nothing_else": 3})

with pytest.raises(ValueError, match="labels tensor should be of shape"):
with pytest.raises(ValueError, match="labels should be index based"):
# Note: the error message isn't ideal, but that's because the label heuristic found the img as the label
# It's OK, it's an edge-case. The important thing is that this fails loudly instead of passing silently
cutmix_mixup(imgs)

with pytest.raises(ValueError, match="When using the default labels_getter"):
cutmix_mixup(imgs, "not_a_tensor")

with pytest.raises(ValueError, match="labels tensor should be of shape"):
cutmix_mixup(imgs, torch.randint(0, 2, size=(2, 3)))

with pytest.raises(ValueError, match="Expected a batched input with 4 dims"):
cutmix_mixup(imgs[None, None], torch.randint(0, num_classes, size=(batch_size,)))

with pytest.raises(ValueError, match="does not match the batch size of the labels"):
cutmix_mixup(imgs, torch.randint(0, num_classes, size=(batch_size + 1,)))

with pytest.raises(ValueError, match="labels tensor should be of shape"):
# The purpose of this check is more about documenting the current
# behaviour of what happens on a Compose(), rather than actually
# asserting the expected behaviour. We may support Compose() in the
# future, e.g. for 2 consecutive CutMix?
labels = torch.randint(0, num_classes, size=(batch_size,))
transforms.Compose([cutmix_mixup, cutmix_mixup])(imgs, labels)
with pytest.raises(ValueError, match="When passing 2D labels"):
wrong_num_classes = num_classes + 1
T(alpha=0.5, num_classes=num_classes)(imgs, torch.randint(0, 2, size=(batch_size, wrong_num_classes)))

with pytest.raises(ValueError, match="but got a tensor of shape"):
cutmix_mixup(imgs, torch.randint(0, 2, size=(2, 3, 4)))

with pytest.raises(ValueError, match="num_classes must be passed"):
T(alpha=0.5)(imgs, torch.randint(0, num_classes, size=(batch_size,)))


@pytest.mark.parametrize("key", ("labels", "LABELS", "LaBeL", "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT"))
Expand Down
28 changes: 21 additions & 7 deletions torchvision/transforms/v2/_augment.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math
import numbers
import warnings
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import PIL.Image
import torch
Expand Down Expand Up @@ -142,7 +142,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:


class _BaseMixUpCutMix(Transform):
def __init__(self, *, alpha: float = 1.0, num_classes: int, labels_getter="default") -> None:
def __init__(self, *, alpha: float = 1.0, num_classes: Optional[int] = None, labels_getter="default") -> None:
super().__init__()
self.alpha = float(alpha)
self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
Expand All @@ -162,10 +162,21 @@ def forward(self, *inputs):
labels = self._labels_getter(inputs)
if not isinstance(labels, torch.Tensor):
raise ValueError(f"The labels must be a tensor, but got {type(labels)} instead.")
elif labels.ndim != 1:
if labels.ndim not in (1, 2):
raise ValueError(
f"labels tensor should be of shape (batch_size,) " f"but got shape {labels.shape} instead."
f"labels should be index based with shape (batch_size,) "
f"or probability based with shape (batch_size, num_classes), "
f"but got a tensor of shape {labels.shape} instead."
)
if labels.ndim == 2 and self.num_classes is not None and labels.shape[-1] != self.num_classes:
raise ValueError(
f"When passing 2D labels, "
f"the number of elements in last dimension must match num_classes: "
f"{labels.shape[-1]} != {self.num_classes}. "
f"You can Leave num_classes to None."
)
if labels.ndim == 1 and self.num_classes is None:
raise ValueError("num_classes must be passed if the labels are index-based (1D)")

params = {
"labels": labels,
Expand Down Expand Up @@ -198,7 +209,8 @@ def _check_image_or_video(self, inpt: torch.Tensor, *, batch_size: int):
)

def _mixup_label(self, label: torch.Tensor, *, lam: float) -> torch.Tensor:
label = one_hot(label, num_classes=self.num_classes)
if label.ndim == 1:
label = one_hot(label, num_classes=self.num_classes) # type: ignore[arg-type]
if not label.dtype.is_floating_point:
label = label.float()
return label.roll(1, 0).mul_(1.0 - lam).add_(label.mul(lam))
Expand All @@ -223,7 +235,8 @@ class MixUp(_BaseMixUpCutMix):

Args:
alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
num_classes (int): number of classes in the batch. Used for one-hot-encoding.
num_classes (int, optional): number of classes in the batch. Used for one-hot-encoding.
Can be None only if the labels are already one-hot-encoded.
labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
common scenario where this transform is called as ``MixUp()(imgs_batch, labels_batch)``.
Expand Down Expand Up @@ -271,7 +284,8 @@ class CutMix(_BaseMixUpCutMix):

Args:
alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
num_classes (int): number of classes in the batch. Used for one-hot-encoding.
num_classes (int, optional): number of classes in the batch. Used for one-hot-encoding.
Can be None only if the labels are already one-hot-encoded.
labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
common scenario where this transform is called as ``CutMix()(imgs_batch, labels_batch)``.
Expand Down