From a304c723b290bb87c442517c8ca590c4f68de18b Mon Sep 17 00:00:00 2001 From: Mahdi Lamb Date: Sat, 18 May 2024 18:19:05 +0100 Subject: [PATCH 1/4] Enable pre-encoded mixup --- torchvision/transforms/v2/_augment.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index cc645d6c8a8..96a9a526df3 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -142,7 +142,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class _BaseMixUpCutMix(Transform): - def __init__(self, *, alpha: float = 1.0, num_classes: int, labels_getter="default") -> None: + def __init__(self, *, alpha: float = 1.0, num_classes: int, labels_getter="default", labels_encoded: bool = False) -> None: super().__init__() self.alpha = float(alpha) self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha])) @@ -150,6 +150,7 @@ def __init__(self, *, alpha: float = 1.0, num_classes: int, labels_getter="defau self.num_classes = num_classes self._labels_getter = _parse_labels_getter(labels_getter) + self._labels_encoded = labels_encoded def forward(self, *inputs): inputs = inputs if len(inputs) > 1 else inputs[0] @@ -162,9 +163,9 @@ def forward(self, *inputs): labels = self._labels_getter(inputs) if not isinstance(labels, torch.Tensor): raise ValueError(f"The labels must be a tensor, but got {type(labels)} instead.") - elif labels.ndim != 1: + elif (not self._labels_encoded and labels.ndim != 1) or (self._labels_encoded and labels.ndim != 2): raise ValueError( - f"labels tensor should be of shape (batch_size,) " f"but got shape {labels.shape} instead." + f"labels tensor should be of shape (batch_size,{self.num_classes if self._labels_encoded else ''}) " f"but got shape {labels.shape} instead." ) params = { @@ -198,7 +199,8 @@ def _check_image_or_video(self, inpt: torch.Tensor, *, batch_size: int): ) def _mixup_label(self, label: torch.Tensor, *, lam: float) -> torch.Tensor: - label = one_hot(label, num_classes=self.num_classes) + if not self._labels_encoded: + label = one_hot(label, num_classes=self.num_classes) if not label.dtype.is_floating_point: label = label.float() return label.roll(1, 0).mul_(1.0 - lam).add_(label.mul(lam)) From e7e2fe400709b655497238cc422f7469bc68b713 Mon Sep 17 00:00:00 2001 From: Mahdi Lamb Date: Mon, 20 May 2024 15:16:23 +0100 Subject: [PATCH 2/4] Update for comments from NH --- test/test_transforms_v2.py | 23 ++++++++++++++++------- torchvision/transforms/v2/_augment.py | 7 ++++--- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 24574eb1a43..190b590c89f 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -2169,26 +2169,29 @@ def test_image_correctness(self, brightness_factor): class TestCutMixMixUp: class DummyDataset: - def __init__(self, size, num_classes): + def __init__(self, size, num_classes, encode_labels:bool): self.size = size self.num_classes = num_classes + self.encode_labels = encode_labels assert size < num_classes def __getitem__(self, idx): img = torch.rand(3, 100, 100) - label = idx # This ensures all labels in a batch are unique and makes testing easier + label = torch.tensor(idx) # This ensures all labels in a batch are unique and makes testing easier + if self.encode_labels: + label = torch.nn.functional.one_hot(label, num_classes=self.num_classes) return img, label def __len__(self): return self.size - @pytest.mark.parametrize("T", [transforms.CutMix, transforms.MixUp]) - def test_supported_input_structure(self, T): + @pytest.mark.parametrize(["T", "encode_labels"], [[transforms.CutMix, False], [transforms.MixUp, False], [transforms.CutMix, True], [transforms.MixUp, True]]) + def test_supported_input_structure(self, T, encode_labels: bool): batch_size = 32 num_classes = 100 - dataset = self.DummyDataset(size=batch_size, num_classes=num_classes) + dataset = self.DummyDataset(size=batch_size, num_classes=num_classes,encode_labels=encode_labels) cutmix_mixup = T(num_classes=num_classes) @@ -2198,7 +2201,10 @@ def test_supported_input_structure(self, T): img, target = next(iter(dl)) input_img_size = img.shape[-3:] assert isinstance(img, torch.Tensor) and isinstance(target, torch.Tensor) - assert target.shape == (batch_size,) + if encode_labels: + assert target.shape == (batch_size, num_classes) + else: + assert target.shape == (batch_size,) def check_output(img, target): assert img.shape == (batch_size, *input_img_size) @@ -2209,7 +2215,10 @@ def check_output(img, target): # After Dataloader, as unpacked input img, target = next(iter(dl)) - assert target.shape == (batch_size,) + if encode_labels: + assert target.shape == (batch_size, num_classes) + else: + assert target.shape == (batch_size,) img, target = cutmix_mixup(img, target) check_output(img, target) diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index cc645d6c8a8..4c2102c4639 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -162,9 +162,9 @@ def forward(self, *inputs): labels = self._labels_getter(inputs) if not isinstance(labels, torch.Tensor): raise ValueError(f"The labels must be a tensor, but got {type(labels)} instead.") - elif labels.ndim != 1: + elif not 0 < labels.ndim <= 2 or (labels.ndim == 2 and labels.shape[1] != self.num_classes): raise ValueError( - f"labels tensor should be of shape (batch_size,) " f"but got shape {labels.shape} instead." + f"labels tensor should be of shape (batch_size,) or (batch_size,num_classes) " f"but got shape {labels.shape} instead." ) params = { @@ -198,7 +198,8 @@ def _check_image_or_video(self, inpt: torch.Tensor, *, batch_size: int): ) def _mixup_label(self, label: torch.Tensor, *, lam: float) -> torch.Tensor: - label = one_hot(label, num_classes=self.num_classes) + if label.ndim == 1: + label = one_hot(label, num_classes=self.num_classes) if not label.dtype.is_floating_point: label = label.float() return label.roll(1, 0).mul_(1.0 - lam).add_(label.mul(lam)) From e31e8f28f8df69b57e1d6db233eb7b27d459b809 Mon Sep 17 00:00:00 2001 From: Mahdi Lamb Date: Fri, 24 May 2024 15:46:20 +0100 Subject: [PATCH 3/4] Apply NH diff --- test/test_transforms_v2.py | 48 ++++++++++++--------------- torchvision/transforms/v2/_augment.py | 26 +++++++++++---- 2 files changed, 40 insertions(+), 34 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 190b590c89f..07235333af4 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -2169,29 +2169,30 @@ def test_image_correctness(self, brightness_factor): class TestCutMixMixUp: class DummyDataset: - def __init__(self, size, num_classes, encode_labels:bool): + def __init__(self, size, num_classes, one_hot_labels): self.size = size self.num_classes = num_classes - self.encode_labels = encode_labels + self.one_hot_labels = one_hot_labels assert size < num_classes def __getitem__(self, idx): img = torch.rand(3, 100, 100) - label = torch.tensor(idx) # This ensures all labels in a batch are unique and makes testing easier - if self.encode_labels: - label = torch.nn.functional.one_hot(label, num_classes=self.num_classes) + label = idx # This ensures all labels in a batch are unique and makes testing easier + if self.one_hot_labels: + label = torch.nn.functional.one_hot(torch.tensor(label), num_classes=self.num_classes) return img, label def __len__(self): return self.size - @pytest.mark.parametrize(["T", "encode_labels"], [[transforms.CutMix, False], [transforms.MixUp, False], [transforms.CutMix, True], [transforms.MixUp, True]]) - def test_supported_input_structure(self, T, encode_labels: bool): + @pytest.mark.parametrize("T", [transforms.CutMix, transforms.MixUp]) + @pytest.mark.parametrize("one_hot_labels", (True, False)) + def test_supported_input_structure(self, T, one_hot_labels): batch_size = 32 num_classes = 100 - dataset = self.DummyDataset(size=batch_size, num_classes=num_classes,encode_labels=encode_labels) + dataset = self.DummyDataset(size=batch_size, num_classes=num_classes, one_hot_labels=one_hot_labels) cutmix_mixup = T(num_classes=num_classes) @@ -2201,10 +2202,7 @@ def test_supported_input_structure(self, T, encode_labels: bool): img, target = next(iter(dl)) input_img_size = img.shape[-3:] assert isinstance(img, torch.Tensor) and isinstance(target, torch.Tensor) - if encode_labels: - assert target.shape == (batch_size, num_classes) - else: - assert target.shape == (batch_size,) + assert target.shape == (batch_size, num_classes) if one_hot_labels else (batch_size,) def check_output(img, target): assert img.shape == (batch_size, *input_img_size) @@ -2215,10 +2213,7 @@ def check_output(img, target): # After Dataloader, as unpacked input img, target = next(iter(dl)) - if encode_labels: - assert target.shape == (batch_size, num_classes) - else: - assert target.shape == (batch_size,) + assert target.shape == (batch_size, num_classes) if one_hot_labels else (batch_size,) img, target = cutmix_mixup(img, target) check_output(img, target) @@ -2273,7 +2268,7 @@ def test_error(self, T): with pytest.raises(ValueError, match="Could not infer where the labels are"): cutmix_mixup({"img": imgs, "Nothing_else": 3}) - with pytest.raises(ValueError, match="labels tensor should be of shape"): + with pytest.raises(ValueError, match="labels should be index based"): # Note: the error message isn't ideal, but that's because the label heuristic found the img as the label # It's OK, it's an edge-case. The important thing is that this fails loudly instead of passing silently cutmix_mixup(imgs) @@ -2281,22 +2276,21 @@ def test_error(self, T): with pytest.raises(ValueError, match="When using the default labels_getter"): cutmix_mixup(imgs, "not_a_tensor") - with pytest.raises(ValueError, match="labels tensor should be of shape"): - cutmix_mixup(imgs, torch.randint(0, 2, size=(2, 3))) - with pytest.raises(ValueError, match="Expected a batched input with 4 dims"): cutmix_mixup(imgs[None, None], torch.randint(0, num_classes, size=(batch_size,))) with pytest.raises(ValueError, match="does not match the batch size of the labels"): cutmix_mixup(imgs, torch.randint(0, num_classes, size=(batch_size + 1,))) - with pytest.raises(ValueError, match="labels tensor should be of shape"): - # The purpose of this check is more about documenting the current - # behaviour of what happens on a Compose(), rather than actually - # asserting the expected behaviour. We may support Compose() in the - # future, e.g. for 2 consecutive CutMix? - labels = torch.randint(0, num_classes, size=(batch_size,)) - transforms.Compose([cutmix_mixup, cutmix_mixup])(imgs, labels) + with pytest.raises(ValueError, match="When passing 2D labels"): + wrong_num_classes = num_classes + 1 + T(alpha=0.5, num_classes=num_classes)(imgs, torch.randint(0, 2, size=(batch_size, wrong_num_classes))) + + with pytest.raises(ValueError, match="but got a tensor of shape"): + cutmix_mixup(imgs, torch.randint(0, 2, size=(2, 3, 4))) + + with pytest.raises(ValueError, match="num_classes must be passed"): + T(alpha=0.5)(imgs, torch.randint(0, num_classes, size=(batch_size,))) @pytest.mark.parametrize("key", ("labels", "LABELS", "LaBeL", "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT")) diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index 48daa271ea4..1d010126545 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -1,7 +1,7 @@ import math import numbers import warnings -from typing import Any, Callable, Dict, List, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import PIL.Image import torch @@ -142,7 +142,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class _BaseMixUpCutMix(Transform): - def __init__(self, *, alpha: float = 1.0, num_classes: int, labels_getter="default", labels_encoded: bool = False) -> None: + def __init__(self, *, alpha: float = 1.0, num_classes: Optional[int] = None, labels_getter="default") -> None: super().__init__() self.alpha = float(alpha) self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha])) @@ -150,7 +150,6 @@ def __init__(self, *, alpha: float = 1.0, num_classes: int, labels_getter="defau self.num_classes = num_classes self._labels_getter = _parse_labels_getter(labels_getter) - self._labels_encoded = labels_encoded def forward(self, *inputs): inputs = inputs if len(inputs) > 1 else inputs[0] @@ -163,10 +162,21 @@ def forward(self, *inputs): labels = self._labels_getter(inputs) if not isinstance(labels, torch.Tensor): raise ValueError(f"The labels must be a tensor, but got {type(labels)} instead.") - elif not 0 < labels.ndim <= 2 or (labels.ndim == 2 and labels.shape[1] != self.num_classes): + if labels.ndim not in (1, 2): raise ValueError( - f"labels tensor should be of shape (batch_size,) or (batch_size,num_classes) " f"but got shape {labels.shape} instead." + f"labels should be index based with shape (batch_size,) " + f"or probability based with shape (batch_size, num_classes), " + f"but got a tensor of shape {labels.shape} instead." ) + if labels.ndim == 2 and self.num_classes is not None and labels.shape[-1] != self.num_classes: + raise ValueError( + f"When passing 2D labels, " + f"the number of elements in last dimension must match num_classes: " + f"{labels.shape[-1]} != {self.num_classes}. " + f"You can Leave num_classes to None." + ) + if labels.ndim == 1 and self.num_classes is None: + raise ValueError("num_classes must be passed if the labels are index-based (1D)") params = { "labels": labels, @@ -225,7 +235,8 @@ class MixUp(_BaseMixUpCutMix): Args: alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1. - num_classes (int): number of classes in the batch. Used for one-hot-encoding. + num_classes (int, optional): number of classes in the batch. Used for one-hot-encoding. + Can be None only if the labels are already one-hot-encoded. labels_getter (callable or "default", optional): indicates how to identify the labels in the input. By default, this will pick the second parameter as the labels if it's a tensor. This covers the most common scenario where this transform is called as ``MixUp()(imgs_batch, labels_batch)``. @@ -273,7 +284,8 @@ class CutMix(_BaseMixUpCutMix): Args: alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1. - num_classes (int): number of classes in the batch. Used for one-hot-encoding. + num_classes (int, optional): number of classes in the batch. Used for one-hot-encoding. + Can be None only if the labels are already one-hot-encoded. labels_getter (callable or "default", optional): indicates how to identify the labels in the input. By default, this will pick the second parameter as the labels if it's a tensor. This covers the most common scenario where this transform is called as ``CutMix()(imgs_batch, labels_batch)``. From f1aa311960e0a5f51076f76574aa57893d676486 Mon Sep 17 00:00:00 2001 From: Mahdi Lamb Date: Tue, 28 May 2024 11:34:58 +0100 Subject: [PATCH 4/4] Try appeasing mypy --- torchvision/transforms/v2/_augment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index 1d010126545..f085ef3ca6e 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -210,7 +210,7 @@ def _check_image_or_video(self, inpt: torch.Tensor, *, batch_size: int): def _mixup_label(self, label: torch.Tensor, *, lam: float) -> torch.Tensor: if label.ndim == 1: - label = one_hot(label, num_classes=self.num_classes) + label = one_hot(label, num_classes=self.num_classes) # type: ignore[arg-type] if not label.dtype.is_floating_point: label = label.float() return label.roll(1, 0).mul_(1.0 - lam).add_(label.mul(lam))