From 2de7cfce9b5f48d4f2e6a4b17f0c6cd377009783 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 7 Sep 2021 18:15:00 +0100 Subject: [PATCH 01/16] Add RandomMixupCutmix. --- test/test_transforms_tensor.py | 43 +++++++++ torchvision/transforms/transforms.py | 131 ++++++++++++++++++++++++++- 2 files changed, 173 insertions(+), 1 deletion(-) diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index aaf7880f124..e9484cc88f9 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -715,3 +715,46 @@ def test_gaussian_blur(device, meth_kwargs): T.GaussianBlur, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, agg_method="max", tol=tol ) + + +@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize('alphas', [ + {"mixup_alpha": 1.0, "cutmix_alpha": 1.0, 'cutmix_p': 1.0}, + {"mixup_alpha": 1.0, "cutmix_alpha": 1.0, 'cutmix_p': 0.0}, + {"mixup_alpha": 1.0, "cutmix_alpha": 1.0, 'p': 0.0}, + {"mixup_alpha": 0.0, "cutmix_alpha": 1.0}, + {"mixup_alpha": 1.0, "cutmix_alpha": 0.0}, +]) +@pytest.mark.parametrize('label_smoothing', [0.0, 0.1]) +@pytest.mark.parametrize('inplace', [True, False]) +def test_random_mixupcutmix(device, alphas, label_smoothing, inplace): + batch_size = 4 + num_classes = 10 + batch = torch.rand(batch_size, 3, 44, 56, device=device) + targets = torch.randint(num_classes, (batch_size, ), device=device, dtype=torch.int64) + + trans = T.RandomMixupCutmix(num_classes, label_smoothing=label_smoothing, inplace=inplace, **alphas) + + original_shape = batch.shape + batch, targets = trans(batch, targets) + assert batch.shape == original_shape + assert targets.shape == (batch_size, num_classes) + + trans.__repr__() + + +def test_random_mixupcutmix_with_invalid_data(): + with pytest.raises(AssertionError, match="Please provide a valid positive value for the num_classes."): + T.RandomMixupCutmix(0) + with pytest.raises(AssertionError, match="Both alpha params can't be zero."): + T.RandomMixupCutmix(10, mixup_alpha=0.0, cutmix_alpha=0.0) + + t = T.RandomMixupCutmix(10) + with pytest.raises(ValueError, match="Batch ndim should be 4."): + t(torch.rand(3, 60, 60), torch.randint(10, (1, ))) + with pytest.raises(ValueError, match="Target ndim should be 1."): + t(torch.rand(32, 3, 60, 60), torch.randint(10, (32, 1))) + with pytest.raises(ValueError, match="Target dtype should be torch.int64."): + t(torch.rand(32, 3, 60, 60), torch.randint(10, (32, ), dtype=torch.int32)) + with pytest.raises(ValueError, match="The batch size should be even."): + t(torch.rand(31, 3, 60, 60), torch.randint(10, (31, ))) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 4b3c08dbce7..55b89ad2112 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -22,7 +22,7 @@ "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize", - "RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize"] + "RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize", 'RandomMixupCutmix'] class Compose: @@ -1953,3 +1953,132 @@ def forward(self, img): def __repr__(self): return self.__class__.__name__ + '(p={})'.format(self.p) + + +class 
RandomMixupCutmix(torch.nn.Module):
+    """Randomly apply Mixup or Cutmix to the provided batch and targets.
+    The class implements the data augmentations as described in the papers
+    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_ and
+    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
+    <https://arxiv.org/abs/1905.04899>`_.
+
+    Args:
+        num_classes (int): number of classes used for one-hot encoding.
+        p (float): probability of the batch being transformed. Default value is 1.0.
+        mixup_alpha (float): hyperparameter of the Beta distribution used for mixup.
+            Set to 0.0 to turn off. Default value is 1.0.
+        cutmix_p (float): probability of using cutmix instead of mixup when both are on.
+            Default value is 0.5.
+        cutmix_alpha (float): hyperparameter of the Beta distribution used for cutmix.
+            Set to 0.0 to turn off. Default value is 0.0.
+        label_smoothing (float): the amount of smoothing used when one-hot encoding.
+            Set to 0.0 to turn off. Default value is 0.0.
+        inplace (bool): boolean to make this transform inplace. Default set to False.
+    """
+
+    def __init__(self, num_classes: int,
+                 p: float = 1.0, mixup_alpha: float = 1.0,
+                 cutmix_p: float = 0.5, cutmix_alpha: float = 0.0,
+                 label_smoothing: float = 0.0, inplace: bool = False) -> None:
+        super().__init__()
+        assert num_classes > 0, "Please provide a valid positive value for the num_classes."
+        assert mixup_alpha > 0 or cutmix_alpha > 0, "Both alpha params can't be zero."
+
+        self.num_classes = num_classes
+        self.p = p
+        self.mixup_alpha = mixup_alpha
+        self.cutmix_p = cutmix_p
+        self.cutmix_alpha = cutmix_alpha
+        self.label_smoothing = label_smoothing
+        self.inplace = inplace
+
+        # torch.distributions.* are not JIT scriptable. See https://github.com/pytorch/pytorch/issues/29843
+        self._mixup_dist = torch.distributions.Beta(self.mixup_alpha,
+                                                    self.mixup_alpha) if self.mixup_alpha > 0 else None
+        self._cutmix_dist = torch.distributions.Beta(self.cutmix_alpha,
+                                                     self.cutmix_alpha) if self.cutmix_alpha > 0 else None
+
+    def _smooth_one_hot(self, target: Tensor) -> Tensor:
+        N = target.shape[0]
+        device = target.device
+        v = torch.full(size=(N, 1), fill_value=1 - self.label_smoothing, device=device)
+        return torch.full(size=(N, self.num_classes), fill_value=self.label_smoothing / self.num_classes,
+                          device=device).scatter_add_(1, target.unsqueeze(1), v)
+
+    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
+        """
+        Args:
+            batch (Tensor): Float tensor of size (B, C, H, W)
+            target (Tensor): Integer tensor of size (B, )
+
+        Returns:
+            Tensor: Randomly transformed batch.
+        """
+        if batch.ndim != 4:
+            raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim))
+        elif target.ndim != 1:
+            raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
+        elif target.dtype != torch.int64:
+            raise ValueError("Target dtype should be torch.int64. 
Got {}".format(target.dtype)) + elif batch.size(0) % 2 != 0: + # speed optimization, see below + raise ValueError("The batch size should be even.") + + if not self.inplace: + batch = batch.clone() + # target = target.clone() + + target = self._smooth_one_hot(target) + if torch.rand(1).item() >= self.p: + return batch, target + + # It's faster to flip the batch instead of shuffling it to create image pairs + batch_flipped = batch.flip(0) + target_flipped = target.flip(0) + + if self._mixup_dist is None: + use_mixup = False + else: + use_mixup = self._cutmix_dist is None or torch.rand(1).item() >= self.cutmix_p + + if use_mixup: + # Implemented as on mixup paper, page 3. + lambda_param = self._mixup_dist.sample() + batch_flipped.mul_(1.0 - lambda_param) + batch.mul_(lambda_param).add_(batch_flipped) + else: + # Implemented as on cutmix paper, page 12 (with minor corrections on typos). + lambda_param = self._cutmix_dist.sample() + W, H = F.get_image_size(batch) + + r_x = torch.randint(W, (1,)) + r_y = torch.randint(H, (1,)) + + r = 0.5 * math.sqrt(1.0 - lambda_param) + r_w_half = int(r * W) + r_h_half = int(r * H) + + x1 = torch.clamp(r_x - r_w_half, min=0) + y1 = torch.clamp(r_y - r_h_half, min=0) + x2 = torch.clamp(r_x + r_w_half, max=W) + y2 = torch.clamp(r_y + r_h_half, max=H) + + batch[:, :, y1:y2, x1:x2] = batch_flipped[:, :, y1:y2, x1:x2] + lambda_param = 1.0 - (x2 - x1) * (y2 - y1) / (W * H) + + target_flipped.mul_(1.0 - lambda_param) + target.mul_(lambda_param).add_(target_flipped) + + return batch, target + + def __repr__(self) -> str: + s = self.__class__.__name__ + '(' + s += 'num_classes={num_classes}' + s += ', p={p}' + s += ', mixup_alpha={mixup_alpha}' + s += ', cutmix_p={cutmix_p}' + s += ', cutmix_alpha={cutmix_alpha}' + s += ', label_smoothing={label_smoothing}' + s += ', inplace={inplace}' + s += ')' + return s.format(**self.__dict__) From 55aedd3c038409174931b5f1be61e67598cab73d Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 7 Sep 2021 19:46:00 +0100 Subject: [PATCH 02/16] Add test with real data. 
--- test/test_transforms_tensor.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index e9484cc88f9..323ad00944c 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -1,6 +1,8 @@ import os import torch +from torch._utils_internal import get_file_path_2 from torchvision import transforms as T +from torchvision.io import read_image from torchvision.transforms import functional as F from torchvision.transforms import InterpolationMode @@ -758,3 +760,28 @@ def test_random_mixupcutmix_with_invalid_data(): t(torch.rand(32, 3, 60, 60), torch.randint(10, (32, ), dtype=torch.int32)) with pytest.raises(ValueError, match="The batch size should be even."): t(torch.rand(31, 3, 60, 60), torch.randint(10, (31, ))) + + +def test_random_mixupcutmix_with_real_data(): + torch.manual_seed(112) + + resize = T.Resize((224, 224)) + mixup = T.RandomMixupCutmix(2, cutmix_alpha=1.0, mixup_alpha=1.0, label_smoothing=0.1) + + images = [] + for test_file in [("encode_jpeg", "grace_hopper_517x606.jpg"), ("fakedata", "logos", "rgb_pytorch.png")]: + fullpath = (os.path.dirname(os.path.abspath(__file__)), 'assets') + test_file + img = read_image(get_file_path_2(*fullpath)) + images.append(resize(img)) + + batch = torch.stack(images).to(torch.float32) + targets = torch.tensor([0, 1]) + + stats = [] + for _ in range(25): + b, t = mixup(batch, targets) + stats.append([b.mean().item(), b.std().item(), t.mean().item(), t.std().item()]) + + torch.testing.assert_close( + torch.tensor(stats).mean(dim=0), + torch.tensor([46.9443, 58.3993, 0.5000, 0.1987]), rtol=0.0, atol=1e-4) From 33d9575eb5efca771f23d7ff539c0c974daaa4de Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 7 Sep 2021 22:13:43 +0100 Subject: [PATCH 03/16] Use dataloader and collate in the test. 
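The updated test below feeds the transform through the DataLoader's collate_fn as a one-liner. Unpacked into a named function it does roughly the following (a sketch, not part of the patch; the constructor arguments mirror the ones used in the test):

    import torch
    from torchvision import transforms as T

    mixup = T.RandomMixupCutmix(2, cutmix_alpha=1.0, mixup_alpha=1.0, label_smoothing=0.1)

    # Equivalent of: lambda batch: mixup(*(torch.stack(x) for x in zip(*batch)))
    def collate_fn(samples):
        images, labels = zip(*samples)   # list of (image, label) pairs -> two tuples
        batch = torch.stack(images)      # (B, C, H, W) float batch
        target = torch.stack(labels)     # (B,) int64 class indices
        return mixup(batch, target)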
--- test/test_transforms_tensor.py | 28 ++++++++++++++++------------ torchvision/transforms/transforms.py | 6 +++--- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 323ad00944c..3dbda413b5c 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -1,6 +1,7 @@ import os import torch from torch._utils_internal import get_file_path_2 +from torch.utils.data import TensorDataset, DataLoader from torchvision import transforms as T from torchvision.io import read_image from torchvision.transforms import functional as F @@ -763,25 +764,28 @@ def test_random_mixupcutmix_with_invalid_data(): def test_random_mixupcutmix_with_real_data(): - torch.manual_seed(112) - - resize = T.Resize((224, 224)) - mixup = T.RandomMixupCutmix(2, cutmix_alpha=1.0, mixup_alpha=1.0, label_smoothing=0.1) + torch.manual_seed(12) + # Build dummy dataset images = [] for test_file in [("encode_jpeg", "grace_hopper_517x606.jpg"), ("fakedata", "logos", "rgb_pytorch.png")]: fullpath = (os.path.dirname(os.path.abspath(__file__)), 'assets') + test_file img = read_image(get_file_path_2(*fullpath)) - images.append(resize(img)) + images.append(F.resize(img, [224, 224])) + dataset = TensorDataset(torch.stack(images).to(torch.float32), torch.tensor([0, 1])) - batch = torch.stack(images).to(torch.float32) - targets = torch.tensor([0, 1]) + # Use mixup in the collate + mixup = T.RandomMixupCutmix(2, cutmix_alpha=1.0, mixup_alpha=1.0, label_smoothing=0.1) + dataloader = DataLoader(dataset, batch_size=2, + collate_fn=lambda batch: mixup(*(torch.stack(x) for x in zip(*batch)))) + # Test against known statistics about the produced images stats = [] for _ in range(25): - b, t = mixup(batch, targets) - stats.append([b.mean().item(), b.std().item(), t.mean().item(), t.std().item()]) - + for b, t in dataloader: + stats.append(torch.stack([b.mean(), b.std(), t.std()])) + torch.testing.assert_close( - torch.tensor(stats).mean(dim=0), - torch.tensor([46.9443, 58.3993, 0.5000, 0.1987]), rtol=0.0, atol=1e-4) + torch.stack(stats).mean(dim=0), + torch.tensor([46.94434738, 64.79092407, 0.23949696]) + ) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 55b89ad2112..f762b7a815a 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -2043,12 +2043,12 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: if use_mixup: # Implemented as on mixup paper, page 3. - lambda_param = self._mixup_dist.sample() + lambda_param = self._mixup_dist.sample().item() batch_flipped.mul_(1.0 - lambda_param) batch.mul_(lambda_param).add_(batch_flipped) else: # Implemented as on cutmix paper, page 12 (with minor corrections on typos). 
- lambda_param = self._cutmix_dist.sample() + lambda_param = self._cutmix_dist.sample().item() W, H = F.get_image_size(batch) r_x = torch.randint(W, (1,)) @@ -2064,7 +2064,7 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: y2 = torch.clamp(r_y + r_h_half, max=H) batch[:, :, y1:y2, x1:x2] = batch_flipped[:, :, y1:y2, x1:x2] - lambda_param = 1.0 - (x2 - x1) * (y2 - y1) / (W * H) + lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H)) target_flipped.mul_(1.0 - lambda_param) target.mul_(lambda_param).add_(target_flipped) From c15dce0d1cbc73991f9b46a0e626bfeaf87d2c02 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 8 Sep 2021 12:06:01 +0100 Subject: [PATCH 04/16] Making RandomMixupCutmix JIT scriptable. --- test/test_transforms_tensor.py | 18 +++++++++++------- torchvision/transforms/transforms.py | 22 ++++++++-------------- 2 files changed, 19 insertions(+), 21 deletions(-) diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 3dbda413b5c..76f838f5c7b 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -731,19 +731,23 @@ def test_gaussian_blur(device, meth_kwargs): @pytest.mark.parametrize('label_smoothing', [0.0, 0.1]) @pytest.mark.parametrize('inplace', [True, False]) def test_random_mixupcutmix(device, alphas, label_smoothing, inplace): - batch_size = 4 + batch_size = 32 num_classes = 10 batch = torch.rand(batch_size, 3, 44, 56, device=device) targets = torch.randint(num_classes, (batch_size, ), device=device, dtype=torch.int64) - trans = T.RandomMixupCutmix(num_classes, label_smoothing=label_smoothing, inplace=inplace, **alphas) + fn = T.RandomMixupCutmix(num_classes, label_smoothing=label_smoothing, inplace=inplace, **alphas) + scripted_fn = torch.jit.script(fn) + + seed = torch.seed() + output = fn(batch.clone(), targets.clone()) - original_shape = batch.shape - batch, targets = trans(batch, targets) - assert batch.shape == original_shape - assert targets.shape == (batch_size, num_classes) + torch.manual_seed(seed) + output_scripted = scripted_fn(batch.clone(), targets.clone()) + assert_equal(output[0], output_scripted[0]) + assert_equal(output[1], output_scripted[1]) - trans.__repr__() + fn.__repr__() def test_random_mixupcutmix_with_invalid_data(): diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index f762b7a815a..63fc668d90d 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -1992,12 +1992,6 @@ def __init__(self, num_classes: int, self.label_smoothing = label_smoothing self.inplace = inplace - # Torch.distributions.* are not JIT scriptable. see https://github.com/pytorch/pytorch/issues/29843 - self._mixup_dist = torch.distributions.Beta(self.mixup_alpha, - self.mixup_alpha) if self.mixup_alpha > 0 else None - self._cutmix_dist = torch.distributions.Beta(self.cutmix_alpha, - self.cutmix_alpha) if self.cutmix_alpha > 0 else None - def _smooth_one_hot(self, target: Tensor) -> Tensor: N = target.shape[0] device = target.device @@ -2036,19 +2030,19 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: batch_flipped = batch.flip(0) target_flipped = target.flip(0) - if self._mixup_dist is None: + if self.mixup_alpha <= 0.0: use_mixup = False else: - use_mixup = self._cutmix_dist is None or torch.rand(1).item() >= self.cutmix_p + use_mixup = self.cutmix_alpha <= 0.0 or torch.rand(1).item() >= self.cutmix_p if use_mixup: # Implemented as on mixup paper, page 3. 
- lambda_param = self._mixup_dist.sample().item() + lambda_param = float(torch._sample_dirichlet(torch.tensor([self.mixup_alpha, self.mixup_alpha]))[0]) batch_flipped.mul_(1.0 - lambda_param) batch.mul_(lambda_param).add_(batch_flipped) else: # Implemented as on cutmix paper, page 12 (with minor corrections on typos). - lambda_param = self._cutmix_dist.sample().item() + lambda_param = float(torch._sample_dirichlet(torch.tensor([self.cutmix_alpha, self.cutmix_alpha]))[0]) W, H = F.get_image_size(batch) r_x = torch.randint(W, (1,)) @@ -2058,10 +2052,10 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: r_w_half = int(r * W) r_h_half = int(r * H) - x1 = torch.clamp(r_x - r_w_half, min=0) - y1 = torch.clamp(r_y - r_h_half, min=0) - x2 = torch.clamp(r_x + r_w_half, max=W) - y2 = torch.clamp(r_y + r_h_half, max=H) + x1 = int(torch.clamp(r_x - r_w_half, min=0)) + y1 = int(torch.clamp(r_y - r_h_half, min=0)) + x2 = int(torch.clamp(r_x + r_w_half, max=W)) + y2 = int(torch.clamp(r_y + r_h_half, max=H)) batch[:, :, y1:y2, x1:x2] = batch_flipped[:, :, y1:y2, x1:x2] lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H)) From 6f2ebeaadf6fb5a695ad33edb20cc5f06a8cde55 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 9 Sep 2021 18:32:45 +0100 Subject: [PATCH 05/16] Move out label_smoothing and try roll instead of flip --- test/test_transforms_tensor.py | 11 ++++------- torchvision/transforms/transforms.py | 24 +++++------------------- 2 files changed, 9 insertions(+), 26 deletions(-) diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 76f838f5c7b..744b94cfddb 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -728,15 +728,14 @@ def test_gaussian_blur(device, meth_kwargs): {"mixup_alpha": 0.0, "cutmix_alpha": 1.0}, {"mixup_alpha": 1.0, "cutmix_alpha": 0.0}, ]) -@pytest.mark.parametrize('label_smoothing', [0.0, 0.1]) @pytest.mark.parametrize('inplace', [True, False]) -def test_random_mixupcutmix(device, alphas, label_smoothing, inplace): +def test_random_mixupcutmix(device, alphas, inplace): batch_size = 32 num_classes = 10 batch = torch.rand(batch_size, 3, 44, 56, device=device) targets = torch.randint(num_classes, (batch_size, ), device=device, dtype=torch.int64) - fn = T.RandomMixupCutmix(num_classes, label_smoothing=label_smoothing, inplace=inplace, **alphas) + fn = T.RandomMixupCutmix(num_classes, inplace=inplace, **alphas) scripted_fn = torch.jit.script(fn) seed = torch.seed() @@ -763,8 +762,6 @@ def test_random_mixupcutmix_with_invalid_data(): t(torch.rand(32, 3, 60, 60), torch.randint(10, (32, 1))) with pytest.raises(ValueError, match="Target dtype should be torch.int64."): t(torch.rand(32, 3, 60, 60), torch.randint(10, (32, ), dtype=torch.int32)) - with pytest.raises(ValueError, match="The batch size should be even."): - t(torch.rand(31, 3, 60, 60), torch.randint(10, (31, ))) def test_random_mixupcutmix_with_real_data(): @@ -779,7 +776,7 @@ def test_random_mixupcutmix_with_real_data(): dataset = TensorDataset(torch.stack(images).to(torch.float32), torch.tensor([0, 1])) # Use mixup in the collate - mixup = T.RandomMixupCutmix(2, cutmix_alpha=1.0, mixup_alpha=1.0, label_smoothing=0.1) + mixup = T.RandomMixupCutmix(2, cutmix_alpha=1.0, mixup_alpha=1.0) dataloader = DataLoader(dataset, batch_size=2, collate_fn=lambda batch: mixup(*(torch.stack(x) for x in zip(*batch)))) @@ -791,5 +788,5 @@ def test_random_mixupcutmix_with_real_data(): torch.testing.assert_close( torch.stack(stats).mean(dim=0), - 
torch.tensor([46.94434738, 64.79092407, 0.23949696]) + torch.tensor([46.931968688964844, 69.97343444824219, 0.459820032119751]) ) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 63fc668d90d..87994a40f3f 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -1971,15 +1971,13 @@ class RandomMixupCutmix(torch.nn.Module): Default value is 0.5. cutmix_alpha (float): hyperparameter of the Beta distribution used for cutmix. Set to 0.0 to turn off. Default value is 0.0. - label_smoothing (float): the amount of smoothing using when one-hot encoding. - Set to 0.0 to turn off. Default value is 0.0. inplace (bool): boolean to make this transform inplace. Default set to False. """ def __init__(self, num_classes: int, p: float = 1.0, mixup_alpha: float = 1.0, cutmix_p: float = 0.5, cutmix_alpha: float = 0.0, - label_smoothing: float = 0.0, inplace: bool = False) -> None: + inplace: bool = False) -> None: super().__init__() assert num_classes > 0, "Please provide a valid positive value for the num_classes." assert mixup_alpha > 0 or cutmix_alpha > 0, "Both alpha params can't be zero." @@ -1989,16 +1987,8 @@ def __init__(self, num_classes: int, self.mixup_alpha = mixup_alpha self.cutmix_p = cutmix_p self.cutmix_alpha = cutmix_alpha - self.label_smoothing = label_smoothing self.inplace = inplace - def _smooth_one_hot(self, target: Tensor) -> Tensor: - N = target.shape[0] - device = target.device - v = torch.full(size=(N, 1), fill_value=1 - self.label_smoothing, device=device) - return torch.full(size=(N, self.num_classes), fill_value=self.label_smoothing / self.num_classes, - device=device).scatter_add_(1, target.unsqueeze(1), v) - def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: """ Args: @@ -2014,21 +2004,18 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: raise ValueError("Target ndim should be 1. Got {}".format(target.ndim)) elif target.dtype != torch.int64: raise ValueError("Target dtype should be torch.int64. Got {}".format(target.dtype)) - elif batch.size(0) % 2 != 0: - # speed optimization, see below - raise ValueError("The batch size should be even.") if not self.inplace: batch = batch.clone() # target = target.clone() - target = self._smooth_one_hot(target) + target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=torch.float32) if torch.rand(1).item() >= self.p: return batch, target - # It's faster to flip the batch instead of shuffling it to create image pairs - batch_flipped = batch.flip(0) - target_flipped = target.flip(0) + # It's faster to roll the batch by one instead of shuffling it to create image pairs + batch_flipped = batch.roll(1) + target_flipped = target.roll(1) if self.mixup_alpha <= 0.0: use_mixup = False @@ -2072,7 +2059,6 @@ def __repr__(self) -> str: s += ', mixup_alpha={mixup_alpha}' s += ', cutmix_p={cutmix_p}' s += ', cutmix_alpha={cutmix_alpha}' - s += ', label_smoothing={label_smoothing}' s += ', inplace={inplace}' s += ')' return s.format(**self.__dict__) From 67acd892cd54b5b05674314c5f6a9d9f0bd35c96 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 9 Sep 2021 20:23:37 +0100 Subject: [PATCH 06/16] Adding mixup/cutmix in references script. 
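The hunk below hooks the transform into the training script via the DataLoader's collate_fn. Written out as a named function, the one-line lambda amounts to roughly this (a sketch; num_classes and the alpha values are illustrative stand-ins for the script's command-line arguments, and the public default_collate import is assumed):

    import torchvision
    from torch.utils.data.dataloader import default_collate

    mixupcutmix = torchvision.transforms.RandomMixupCutmix(
        num_classes=1000, mixup_alpha=0.2, cutmix_alpha=1.0)

    def mixup_collate_fn(samples):
        # default_collate stacks the sampled (image, label) pairs into two batch tensors;
        # the transform then mixes the images and produces soft one-hot targets.
        batch, target = default_collate(samples)
        return mixupcutmix(batch, target)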
--- references/classification/train.py | 12 ++++++++++-- test/test_transforms_tensor.py | 2 +- torchvision/transforms/transforms.py | 17 +++++++++-------- 3 files changed, 20 insertions(+), 11 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index 89eae31c2cd..2d4a7c9a6fc 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -165,10 +165,16 @@ def main(args): train_dir = os.path.join(args.data_path, 'train') val_dir = os.path.join(args.data_path, 'val') dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args) + + collate_fn = None + if args.mixup_alpha > 0.0 or args.cutmix_alpha > 0.0: + mixupcutmix = torchvision.transforms.RandomMixupCutmix(len(dataset.classes), mixup_alpha=args.mixup_alpha, + cutmix_alpha=args.cutmix_alpha) + collate_fn = lambda batch: mixupcutmix(*torch.utils.data._utils.collate.default_collate(batch)) # noqa: E731 data_loader = torch.utils.data.DataLoader( dataset, batch_size=args.batch_size, - sampler=train_sampler, num_workers=args.workers, pin_memory=True) - + sampler=train_sampler, num_workers=args.workers, pin_memory=True, + collate_fn=collate_fn) data_loader_test = torch.utils.data.DataLoader( dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True) @@ -273,6 +279,8 @@ def get_args_parser(add_help=True): parser.add_argument('--label-smoothing', default=0.0, type=float, help='label smoothing (default: 0.0)', dest='label_smoothing') + parser.add_argument('--mixup-alpha', default=0.0, type=float, help='mixup alpha (default: 0.0)') + parser.add_argument('--cutmix-alpha', default=0.0, type=float, help='cutmix alpha (default: 0.0)') parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs') parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma') parser.add_argument('--print-freq', default=10, type=int, help='print frequency') diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 744b94cfddb..058ee1a0cfc 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -788,5 +788,5 @@ def test_random_mixupcutmix_with_real_data(): torch.testing.assert_close( torch.stack(stats).mean(dim=0), - torch.tensor([46.931968688964844, 69.97343444824219, 0.459820032119751]) + torch.tensor([46.9443473815918, 64.79092407226562, 0.459820032119751]) ) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 87994a40f3f..6c7e47c221f 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -1955,8 +1955,9 @@ def __repr__(self): return self.__class__.__name__ + '(p={})'.format(self.p) +# TODO: move this to references before merging and delete the tests class RandomMixupCutmix(torch.nn.Module): - """Randomly apply Mixum or Cutmix to the provided batch and targets. + """Randomly apply Mixup or Cutmix to the provided batch and targets. 
The class implements the data augmentations as described in the papers `"mixup: Beyond Empirical Risk Minimization" `_ and `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" @@ -2014,8 +2015,8 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: return batch, target # It's faster to roll the batch by one instead of shuffling it to create image pairs - batch_flipped = batch.roll(1) - target_flipped = target.roll(1) + batch_rolled = batch.roll(1, 0) + target_rolled = target.roll(1) if self.mixup_alpha <= 0.0: use_mixup = False @@ -2025,8 +2026,8 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: if use_mixup: # Implemented as on mixup paper, page 3. lambda_param = float(torch._sample_dirichlet(torch.tensor([self.mixup_alpha, self.mixup_alpha]))[0]) - batch_flipped.mul_(1.0 - lambda_param) - batch.mul_(lambda_param).add_(batch_flipped) + batch_rolled.mul_(1.0 - lambda_param) + batch.mul_(lambda_param).add_(batch_rolled) else: # Implemented as on cutmix paper, page 12 (with minor corrections on typos). lambda_param = float(torch._sample_dirichlet(torch.tensor([self.cutmix_alpha, self.cutmix_alpha]))[0]) @@ -2044,11 +2045,11 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: x2 = int(torch.clamp(r_x + r_w_half, max=W)) y2 = int(torch.clamp(r_y + r_h_half, max=H)) - batch[:, :, y1:y2, x1:x2] = batch_flipped[:, :, y1:y2, x1:x2] + batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2] lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H)) - target_flipped.mul_(1.0 - lambda_param) - target.mul_(lambda_param).add_(target_flipped) + target_rolled.mul_(1.0 - lambda_param) + target.mul_(lambda_param).add_(target_rolled) return batch, target From 544967e0f5bb0a2b62f2e6a0b36bf4a6206aa0ae Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 9 Sep 2021 21:57:52 +0100 Subject: [PATCH 07/16] Handle one-hot encoded target in accuracy. --- references/classification/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/references/classification/utils.py b/references/classification/utils.py index 644f1c4708a..53a691de10a 100644 --- a/references/classification/utils.py +++ b/references/classification/utils.py @@ -178,6 +178,8 @@ def accuracy(output, target, topk=(1,)): with torch.no_grad(): maxk = max(topk) batch_size = target.size(0) + if target.ndim == 2: + target = target.max(dim=1)[1] _, pred = output.topk(maxk, 1, True, True) pred = pred.t() From 0e128f3d5b0c207a1f72f2dea15824eab8f7a79b Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 13 Sep 2021 16:32:56 +0100 Subject: [PATCH 08/16] Add support of devices on tests. 
--- references/classification/train.py | 3 ++- test/test_transforms_tensor.py | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index 2d4a7c9a6fc..a3bd744f6a6 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -4,6 +4,7 @@ import torch import torch.utils.data +from torch.utils.data.dataloader import default_collate from torch import nn import torchvision from torchvision.transforms.functional import InterpolationMode @@ -170,7 +171,7 @@ def main(args): if args.mixup_alpha > 0.0 or args.cutmix_alpha > 0.0: mixupcutmix = torchvision.transforms.RandomMixupCutmix(len(dataset.classes), mixup_alpha=args.mixup_alpha, cutmix_alpha=args.cutmix_alpha) - collate_fn = lambda batch: mixupcutmix(*torch.utils.data._utils.collate.default_collate(batch)) # noqa: E731 + collate_fn = lambda batch: mixupcutmix(*default_collate(batch)) # noqa: E731 data_loader = torch.utils.data.DataLoader( dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True, diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 058ee1a0cfc..a87a0ca543a 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -764,7 +764,8 @@ def test_random_mixupcutmix_with_invalid_data(): t(torch.rand(32, 3, 60, 60), torch.randint(10, (32, ), dtype=torch.int32)) -def test_random_mixupcutmix_with_real_data(): +@pytest.mark.parametrize('device', cpu_and_gpu()) +def test_random_mixupcutmix_with_real_data(device): torch.manual_seed(12) # Build dummy dataset @@ -773,7 +774,8 @@ def test_random_mixupcutmix_with_real_data(): fullpath = (os.path.dirname(os.path.abspath(__file__)), 'assets') + test_file img = read_image(get_file_path_2(*fullpath)) images.append(F.resize(img, [224, 224])) - dataset = TensorDataset(torch.stack(images).to(torch.float32), torch.tensor([0, 1])) + dataset = TensorDataset(torch.stack(images).to(device=device, dtype=torch.float32), + torch.tensor([0, 1], device=device)) # Use mixup in the collate mixup = T.RandomMixupCutmix(2, cutmix_alpha=1.0, mixup_alpha=1.0) From 3f199027e6471ec6e5c2021c341773906a0360f4 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 13 Sep 2021 18:29:32 +0100 Subject: [PATCH 09/16] Separate Mixup from Cutmix. 
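Before the split below, it helps to see the two mixing rules side by side: both draw lambda from Beta(alpha, alpha) and mix the batch with a rolled copy of itself, but cutmix realises lambda as a pasted rectangle and then recomputes it from the actual box area. A rough sketch of that geometry (torch.distributions is used here for readability; the transforms themselves sample via torch._sample_dirichlet to stay scriptable, and the image size is illustrative):

    import math
    import torch

    lam = torch.distributions.Beta(1.0, 1.0).sample().item()

    # Mixup: plain convex combination
    #   batch  = lam * batch  + (1 - lam) * batch_rolled
    #   target = lam * target + (1 - lam) * target_rolled

    # Cutmix: paste a box from the rolled batch, then recompute lam from its area
    W = H = 224
    r_x = int(torch.randint(W, (1,)))
    r_y = int(torch.randint(H, (1,)))
    r = 0.5 * math.sqrt(1.0 - lam)
    r_w_half, r_h_half = int(r * W), int(r * H)
    x1, y1 = max(r_x - r_w_half, 0), max(r_y - r_h_half, 0)
    x2, y2 = min(r_x + r_w_half, W), min(r_y + r_h_half, H)
    lam_box = 1.0 - (x2 - x1) * (y2 - y1) / (W * H)   # weight actually applied to the targets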
--- references/classification/train.py | 11 +- test/test_transforms.py | 3 +- test/test_transforms_tensor.py | 40 +++---- torchvision/transforms/transforms.py | 162 +++++++++++++++++++-------- 4 files changed, 143 insertions(+), 73 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index a3bd744f6a6..565cebe9797 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -168,9 +168,14 @@ def main(args): dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args) collate_fn = None - if args.mixup_alpha > 0.0 or args.cutmix_alpha > 0.0: - mixupcutmix = torchvision.transforms.RandomMixupCutmix(len(dataset.classes), mixup_alpha=args.mixup_alpha, - cutmix_alpha=args.cutmix_alpha) + num_classes = len(dataset.classes) + mixup_transforms = [] + if args.mixup_alpha > 0.0: + mixup_transforms.append(torchvision.transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha)) + if args.cutmix_alpha > 0.0: + mixup_transforms.append(torchvision.transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha)) + if mixup_transforms: + mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms, p=[0.5, 0.5]) collate_fn = lambda batch: mixupcutmix(*default_collate(batch)) # noqa: E731 data_loader = torch.utils.data.DataLoader( dataset, batch_size=args.batch_size, diff --git a/test/test_transforms.py b/test/test_transforms.py index 675e79ac3ba..541b0adfb6c 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1311,7 +1311,8 @@ def test_random_choice(): transforms.Resize(15), transforms.Resize(20), transforms.CenterCrop(10) - ] + ], + [1 / 3, 1 / 3, 1 / 3] ) img = transforms.ToPILImage()(torch.rand(3, 25, 25)) num_samples = 250 diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index a87a0ca543a..131f2e2817a 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -2,6 +2,7 @@ import torch from torch._utils_internal import get_file_path_2 from torch.utils.data import TensorDataset, DataLoader +from torch.utils.data.dataloader import default_collate from torchvision import transforms as T from torchvision.io import read_image from torchvision.transforms import functional as F @@ -721,21 +722,16 @@ def test_gaussian_blur(device, meth_kwargs): @pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('alphas', [ - {"mixup_alpha": 1.0, "cutmix_alpha": 1.0, 'cutmix_p': 1.0}, - {"mixup_alpha": 1.0, "cutmix_alpha": 1.0, 'cutmix_p': 0.0}, - {"mixup_alpha": 1.0, "cutmix_alpha": 1.0, 'p': 0.0}, - {"mixup_alpha": 0.0, "cutmix_alpha": 1.0}, - {"mixup_alpha": 1.0, "cutmix_alpha": 0.0}, -]) +@pytest.mark.parametrize('tranform', [T.RandomMixup, T.RandomCutmix]) +@pytest.mark.parametrize('p', [0.0, 1.0]) @pytest.mark.parametrize('inplace', [True, False]) -def test_random_mixupcutmix(device, alphas, inplace): +def test_random_mixupcutmix(device, tranform, p, inplace): batch_size = 32 num_classes = 10 batch = torch.rand(batch_size, 3, 44, 56, device=device) targets = torch.randint(num_classes, (batch_size, ), device=device, dtype=torch.int64) - fn = T.RandomMixupCutmix(num_classes, inplace=inplace, **alphas) + fn = tranform(num_classes, p=p, inplace=inplace) scripted_fn = torch.jit.script(fn) seed = torch.seed() @@ -749,13 +745,14 @@ def test_random_mixupcutmix(device, alphas, inplace): fn.__repr__() -def test_random_mixupcutmix_with_invalid_data(): +@pytest.mark.parametrize('tranform', [T.RandomMixup, 
T.RandomCutmix]) +def test_random_mixupcutmix_with_invalid_data(tranform): with pytest.raises(AssertionError, match="Please provide a valid positive value for the num_classes."): - T.RandomMixupCutmix(0) - with pytest.raises(AssertionError, match="Both alpha params can't be zero."): - T.RandomMixupCutmix(10, mixup_alpha=0.0, cutmix_alpha=0.0) + tranform(0) + with pytest.raises(AssertionError, match="Alpha param can't be zero."): + tranform(10, alpha=0.0) - t = T.RandomMixupCutmix(10) + t = tranform(10) with pytest.raises(ValueError, match="Batch ndim should be 4."): t(torch.rand(3, 60, 60), torch.randint(10, (1, ))) with pytest.raises(ValueError, match="Target ndim should be 1."): @@ -765,7 +762,11 @@ def test_random_mixupcutmix_with_invalid_data(): @pytest.mark.parametrize('device', cpu_and_gpu()) -def test_random_mixupcutmix_with_real_data(device): +@pytest.mark.parametrize('transform, expected', [ + (T.RandomMixup, [60.77401351928711, 0.5151033997535706]), + (T.RandomCutmix, [70.13909912109375, 0.525851309299469]) +]) +def test_random_mixupcutmix_with_real_data(device, transform, expected): torch.manual_seed(12) # Build dummy dataset @@ -778,17 +779,16 @@ def test_random_mixupcutmix_with_real_data(device): torch.tensor([0, 1], device=device)) # Use mixup in the collate - mixup = T.RandomMixupCutmix(2, cutmix_alpha=1.0, mixup_alpha=1.0) - dataloader = DataLoader(dataset, batch_size=2, - collate_fn=lambda batch: mixup(*(torch.stack(x) for x in zip(*batch)))) + trans = transform(2) + dataloader = DataLoader(dataset, batch_size=2, collate_fn=lambda batch: trans(*default_collate(batch))) # Test against known statistics about the produced images stats = [] for _ in range(25): for b, t in dataloader: - stats.append(torch.stack([b.mean(), b.std(), t.std()])) + stats.append(torch.stack([b.std(), t.std()])) torch.testing.assert_close( torch.stack(stats).mean(dim=0), - torch.tensor([46.9443473815918, 64.79092407226562, 0.459820032119751]) + torch.tensor(expected) ) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 6c7e47c221f..965ee78bf0c 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -22,7 +22,8 @@ "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize", - "RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize", 'RandomMixupCutmix'] + "RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize", 'RandomMixup', + "RandomCutmix"] class Compose: @@ -515,9 +516,20 @@ def __call__(self, img): class RandomChoice(RandomTransforms): """Apply single transformation randomly picked from a list. This transform does not support torchscript. 
""" - def __call__(self, img): - t = random.choice(self.transforms) - return t(img) + def __init__(self, transforms, p=None): + super().__init__(transforms) + if p is not None and not isinstance(p, Sequence): + raise TypeError("Argument transforms should be a sequence") + self.p = p + + def __call__(self, *args): + t = random.choices(self.transforms, weights=self.p)[0] + return t(*args) + + def __repr__(self): + format_string = super().__repr__() + format_string += '(p={0})'.format(self.p) + return format_string class RandomCrop(torch.nn.Module): @@ -1956,38 +1968,103 @@ def __repr__(self): # TODO: move this to references before merging and delete the tests -class RandomMixupCutmix(torch.nn.Module): - """Randomly apply Mixup or Cutmix to the provided batch and targets. - The class implements the data augmentations as described in the papers - `"mixup: Beyond Empirical Risk Minimization" `_ and +class RandomMixup(torch.nn.Module): + """Randomly apply Mixup to the provided batch and targets. + The class implements the data augmentations as described in the paper + `"mixup: Beyond Empirical Risk Minimization" `_. + + Args: + num_classes (int): number of classes used for one-hot encoding. + p (float): probability of the batch being transformed. Default value is 0.5. + alpha (float): hyperparameter of the Beta distribution used for mixup. + Default value is 1.0. + inplace (bool): boolean to make this transform inplace. Default set to False. + """ + + def __init__(self, num_classes: int, + p: float = 0.5, alpha: float = 1.0, + inplace: bool = False) -> None: + super().__init__() + assert num_classes > 0, "Please provide a valid positive value for the num_classes." + assert alpha > 0, "Alpha param can't be zero." + + self.num_classes = num_classes + self.p = p + self.alpha = alpha + self.inplace = inplace + + def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: + """ + Args: + batch (Tensor): Float tensor of size (B, C, H, W) + target (Tensor): Integer tensor of size (B, ) + + Returns: + Tensor: Randomly transformed batch. + """ + if batch.ndim != 4: + raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim)) + elif target.ndim != 1: + raise ValueError("Target ndim should be 1. Got {}".format(target.ndim)) + elif target.dtype != torch.int64: + raise ValueError("Target dtype should be torch.int64. Got {}".format(target.dtype)) + + if not self.inplace: + batch = batch.clone() + # target = target.clone() + + target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=torch.float32) + if torch.rand(1).item() >= self.p: + return batch, target + + # It's faster to roll the batch by one instead of shuffling it to create image pairs + batch_rolled = batch.roll(1, 0) + target_rolled = target.roll(1) + + # Implemented as on mixup paper, page 3. + lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]) + batch_rolled.mul_(1.0 - lambda_param) + batch.mul_(lambda_param).add_(batch_rolled) + + target_rolled.mul_(1.0 - lambda_param) + target.mul_(lambda_param).add_(target_rolled) + + return batch, target + + def __repr__(self) -> str: + s = self.__class__.__name__ + '(' + s += 'num_classes={num_classes}' + s += ', p={p}' + s += ', alpha={alpha}' + s += ', inplace={inplace}' + s += ')' + return s.format(**self.__dict__) + + +class RandomCutmix(torch.nn.Module): + """Randomly apply Cutmix to the provided batch and targets. 
+ The class implements the data augmentations as described in the paper `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" `_. Args: num_classes (int): number of classes used for one-hot encoding. - p (float): probability of the batch being transformed. Default value is 1.0. - mixup_alpha (float): hyperparameter of the Beta distribution used for mixup. - Set to 0.0 to turn off. Default value is 1.0. - cutmix_p (float): probability of using cutmix instead of mixup when both are on. - Default value is 0.5. - cutmix_alpha (float): hyperparameter of the Beta distribution used for cutmix. - Set to 0.0 to turn off. Default value is 0.0. + p (float): probability of the batch being transformed. Default value is 0.5. + alpha (float): hyperparameter of the Beta distribution used for cutmix. + Default value is 1.0. inplace (bool): boolean to make this transform inplace. Default set to False. """ def __init__(self, num_classes: int, - p: float = 1.0, mixup_alpha: float = 1.0, - cutmix_p: float = 0.5, cutmix_alpha: float = 0.0, + p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None: super().__init__() assert num_classes > 0, "Please provide a valid positive value for the num_classes." - assert mixup_alpha > 0 or cutmix_alpha > 0, "Both alpha params can't be zero." + assert alpha > 0, "Alpha param can't be zero." self.num_classes = num_classes self.p = p - self.mixup_alpha = mixup_alpha - self.cutmix_p = cutmix_p - self.cutmix_alpha = cutmix_alpha + self.alpha = alpha self.inplace = inplace def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: @@ -2018,35 +2095,24 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: batch_rolled = batch.roll(1, 0) target_rolled = target.roll(1) - if self.mixup_alpha <= 0.0: - use_mixup = False - else: - use_mixup = self.cutmix_alpha <= 0.0 or torch.rand(1).item() >= self.cutmix_p - - if use_mixup: - # Implemented as on mixup paper, page 3. - lambda_param = float(torch._sample_dirichlet(torch.tensor([self.mixup_alpha, self.mixup_alpha]))[0]) - batch_rolled.mul_(1.0 - lambda_param) - batch.mul_(lambda_param).add_(batch_rolled) - else: - # Implemented as on cutmix paper, page 12 (with minor corrections on typos). - lambda_param = float(torch._sample_dirichlet(torch.tensor([self.cutmix_alpha, self.cutmix_alpha]))[0]) - W, H = F.get_image_size(batch) + # Implemented as on cutmix paper, page 12 (with minor corrections on typos). 
+ lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]) + W, H = F.get_image_size(batch) - r_x = torch.randint(W, (1,)) - r_y = torch.randint(H, (1,)) + r_x = torch.randint(W, (1,)) + r_y = torch.randint(H, (1,)) - r = 0.5 * math.sqrt(1.0 - lambda_param) - r_w_half = int(r * W) - r_h_half = int(r * H) + r = 0.5 * math.sqrt(1.0 - lambda_param) + r_w_half = int(r * W) + r_h_half = int(r * H) - x1 = int(torch.clamp(r_x - r_w_half, min=0)) - y1 = int(torch.clamp(r_y - r_h_half, min=0)) - x2 = int(torch.clamp(r_x + r_w_half, max=W)) - y2 = int(torch.clamp(r_y + r_h_half, max=H)) + x1 = int(torch.clamp(r_x - r_w_half, min=0)) + y1 = int(torch.clamp(r_y - r_h_half, min=0)) + x2 = int(torch.clamp(r_x + r_w_half, max=W)) + y2 = int(torch.clamp(r_y + r_h_half, max=H)) - batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2] - lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H)) + batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2] + lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H)) target_rolled.mul_(1.0 - lambda_param) target.mul_(lambda_param).add_(target_rolled) @@ -2057,9 +2123,7 @@ def __repr__(self) -> str: s = self.__class__.__name__ + '(' s += 'num_classes={num_classes}' s += ', p={p}' - s += ', mixup_alpha={mixup_alpha}' - s += ', cutmix_p={cutmix_p}' - s += ', cutmix_alpha={cutmix_alpha}' + s += ', alpha={alpha}' s += ', inplace={inplace}' s += ')' return s.format(**self.__dict__) From 78e8605e72d6d8fa76946cab258549f0d2ebea2d Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 13 Sep 2021 18:54:48 +0100 Subject: [PATCH 10/16] Add check for floats. --- test/test_transforms_tensor.py | 4 +++- torchvision/transforms/transforms.py | 8 ++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 131f2e2817a..41a58b4e7ac 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -757,7 +757,9 @@ def test_random_mixupcutmix_with_invalid_data(tranform): t(torch.rand(3, 60, 60), torch.randint(10, (1, ))) with pytest.raises(ValueError, match="Target ndim should be 1."): t(torch.rand(32, 3, 60, 60), torch.randint(10, (32, 1))) - with pytest.raises(ValueError, match="Target dtype should be torch.int64."): + with pytest.raises(TypeError, match="Batch dtype should be a float tensor."): + t(torch.randint(256, (32, 3, 60, 60), dtype=torch.uint8), torch.randint(10, (32, ))) + with pytest.raises(TypeError, match="Target dtype should be torch.int64."): t(torch.rand(32, 3, 60, 60), torch.randint(10, (32, ), dtype=torch.int32)) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 965ee78bf0c..b479fa99abe 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -2006,8 +2006,10 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim)) elif target.ndim != 1: raise ValueError("Target ndim should be 1. Got {}".format(target.ndim)) + elif not batch.is_floating_point(): + raise TypeError('Batch dtype should be a float tensor. Got {}.'.format(batch.dtype)) elif target.dtype != torch.int64: - raise ValueError("Target dtype should be torch.int64. Got {}".format(target.dtype)) + raise TypeError("Target dtype should be torch.int64. 
Got {}".format(target.dtype)) if not self.inplace: batch = batch.clone() @@ -2080,8 +2082,10 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim)) elif target.ndim != 1: raise ValueError("Target ndim should be 1. Got {}".format(target.ndim)) + elif not batch.is_floating_point(): + raise TypeError('Batch dtype should be a float tensor. Got {}.'.format(batch.dtype)) elif target.dtype != torch.int64: - raise ValueError("Target dtype should be torch.int64. Got {}".format(target.dtype)) + raise TypeError("Target dtype should be torch.int64. Got {}".format(target.dtype)) if not self.inplace: batch = batch.clone() From de9fa07b568172f98acb4ecfdfcac385f70f13fc Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 13 Sep 2021 19:21:47 +0100 Subject: [PATCH 11/16] Adding device on expect value. --- test/test_transforms_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 41a58b4e7ac..e8a9a920433 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -792,5 +792,5 @@ def test_random_mixupcutmix_with_real_data(device, transform, expected): torch.testing.assert_close( torch.stack(stats).mean(dim=0), - torch.tensor(expected) + torch.tensor(expected, device=device) ) From 3f212fe24da2e6bc4163457957107e30c445716d Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 14 Sep 2021 10:02:41 +0100 Subject: [PATCH 12/16] Remove hardcoded weights. --- references/classification/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/references/classification/train.py b/references/classification/train.py index 565cebe9797..d4ac1aa68b3 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -175,7 +175,7 @@ def main(args): if args.cutmix_alpha > 0.0: mixup_transforms.append(torchvision.transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha)) if mixup_transforms: - mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms, p=[0.5, 0.5]) + mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms) collate_fn = lambda batch: mixupcutmix(*default_collate(batch)) # noqa: E731 data_loader = torch.utils.data.DataLoader( dataset, batch_size=args.batch_size, From 33c29730347e12a41d96eb7d2e7dabea891cfd69 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 15 Sep 2021 13:58:11 +0100 Subject: [PATCH 13/16] One-hot only when necessary. 
--- torchvision/transforms/transforms.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index b479fa99abe..02956346703 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -22,7 +22,7 @@ "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize", - "RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize", 'RandomMixup', + "RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize", "RandomMixup", "RandomCutmix"] @@ -2013,9 +2013,11 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: if not self.inplace: batch = batch.clone() - # target = target.clone() + target = target.clone() + + if target.ndim == 1: + target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=torch.float32) - target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=torch.float32) if torch.rand(1).item() >= self.p: return batch, target @@ -2089,9 +2091,11 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: if not self.inplace: batch = batch.clone() - # target = target.clone() + target = target.clone() + + if target.ndim == 1: + target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=torch.float32) - target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=torch.float32) if torch.rand(1).item() >= self.p: return batch, target From 9abb18b547ab48ff70c47ae3b64a3a6bd66b1511 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 15 Sep 2021 14:02:55 +0100 Subject: [PATCH 14/16] Fix linter. --- torchvision/transforms/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 02956346703..a2439f7000e 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -2016,7 +2016,7 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: target = target.clone() if target.ndim == 1: - target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=torch.float32) + target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=torch.float32) if torch.rand(1).item() >= self.p: return batch, target From eb932b956dde3c028f7785faf04c460e163b7f40 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 15 Sep 2021 15:17:31 +0100 Subject: [PATCH 15/16] Moving mixup and cutmix to references. 
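After the move, the two transforms live next to the training script instead of inside torchvision. A sketch of how they are consumed once this patch is applied (mirrors the train.py hunk below; num_classes and the alpha values are illustrative):

    import torchvision
    from torch.utils.data.dataloader import default_collate

    import transforms  # the new references/classification/transforms.py module

    mixupcutmix = torchvision.transforms.RandomChoice([
        transforms.RandomMixup(num_classes=1000, p=1.0, alpha=0.2),
        transforms.RandomCutmix(num_classes=1000, p=1.0, alpha=1.0),
    ])
    collate_fn = lambda batch: mixupcutmix(*default_collate(batch))  # noqa: E731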
--- references/classification/train.py | 5 +- references/classification/transforms.py | 175 ++++++++++++++++++++++++ test/test_transforms_tensor.py | 79 ----------- torchvision/transforms/transforms.py | 173 +---------------------- 4 files changed, 179 insertions(+), 253 deletions(-) create mode 100644 references/classification/transforms.py diff --git a/references/classification/train.py b/references/classification/train.py index d71ce36a510..3ec9039a018 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -10,6 +10,7 @@ from torchvision.transforms.functional import InterpolationMode import presets +import transforms import utils try: @@ -170,9 +171,9 @@ def main(args): num_classes = len(dataset.classes) mixup_transforms = [] if args.mixup_alpha > 0.0: - mixup_transforms.append(torchvision.transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha)) + mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha)) if args.cutmix_alpha > 0.0: - mixup_transforms.append(torchvision.transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha)) + mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha)) if mixup_transforms: mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms) collate_fn = lambda batch: mixupcutmix(*default_collate(batch)) # noqa: E731 diff --git a/references/classification/transforms.py b/references/classification/transforms.py new file mode 100644 index 00000000000..b3d6f9474ce --- /dev/null +++ b/references/classification/transforms.py @@ -0,0 +1,175 @@ +import math +import torch + +from typing import Tuple +from torch import Tensor +from torchvision.transforms import functional as F + + +class RandomMixup(torch.nn.Module): + """Randomly apply Mixup to the provided batch and targets. + The class implements the data augmentations as described in the paper + `"mixup: Beyond Empirical Risk Minimization" `_. + + Args: + num_classes (int): number of classes used for one-hot encoding. + p (float): probability of the batch being transformed. Default value is 0.5. + alpha (float): hyperparameter of the Beta distribution used for mixup. + Default value is 1.0. + inplace (bool): boolean to make this transform inplace. Default set to False. + """ + + def __init__(self, num_classes: int, + p: float = 0.5, alpha: float = 1.0, + inplace: bool = False) -> None: + super().__init__() + assert num_classes > 0, "Please provide a valid positive value for the num_classes." + assert alpha > 0, "Alpha param can't be zero." + + self.num_classes = num_classes + self.p = p + self.alpha = alpha + self.inplace = inplace + + def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: + """ + Args: + batch (Tensor): Float tensor of size (B, C, H, W) + target (Tensor): Integer tensor of size (B, ) + + Returns: + Tensor: Randomly transformed batch. + """ + if batch.ndim != 4: + raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim)) + elif target.ndim != 1: + raise ValueError("Target ndim should be 1. Got {}".format(target.ndim)) + elif not batch.is_floating_point(): + raise TypeError('Batch dtype should be a float tensor. Got {}.'.format(batch.dtype)) + elif target.dtype != torch.int64: + raise TypeError("Target dtype should be torch.int64. 
Got {}".format(target.dtype)) + + if not self.inplace: + batch = batch.clone() + target = target.clone() + + if target.ndim == 1: + target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=torch.float32) + + if torch.rand(1).item() >= self.p: + return batch, target + + # It's faster to roll the batch by one instead of shuffling it to create image pairs + batch_rolled = batch.roll(1, 0) + target_rolled = target.roll(1) + + # Implemented as on mixup paper, page 3. + lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]) + batch_rolled.mul_(1.0 - lambda_param) + batch.mul_(lambda_param).add_(batch_rolled) + + target_rolled.mul_(1.0 - lambda_param) + target.mul_(lambda_param).add_(target_rolled) + + return batch, target + + def __repr__(self) -> str: + s = self.__class__.__name__ + '(' + s += 'num_classes={num_classes}' + s += ', p={p}' + s += ', alpha={alpha}' + s += ', inplace={inplace}' + s += ')' + return s.format(**self.__dict__) + + +class RandomCutmix(torch.nn.Module): + """Randomly apply Cutmix to the provided batch and targets. + The class implements the data augmentations as described in the paper + `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" + `_. + + Args: + num_classes (int): number of classes used for one-hot encoding. + p (float): probability of the batch being transformed. Default value is 0.5. + alpha (float): hyperparameter of the Beta distribution used for cutmix. + Default value is 1.0. + inplace (bool): boolean to make this transform inplace. Default set to False. + """ + + def __init__(self, num_classes: int, + p: float = 0.5, alpha: float = 1.0, + inplace: bool = False) -> None: + super().__init__() + assert num_classes > 0, "Please provide a valid positive value for the num_classes." + assert alpha > 0, "Alpha param can't be zero." + + self.num_classes = num_classes + self.p = p + self.alpha = alpha + self.inplace = inplace + + def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: + """ + Args: + batch (Tensor): Float tensor of size (B, C, H, W) + target (Tensor): Integer tensor of size (B, ) + + Returns: + Tensor: Randomly transformed batch. + """ + if batch.ndim != 4: + raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim)) + elif target.ndim != 1: + raise ValueError("Target ndim should be 1. Got {}".format(target.ndim)) + elif not batch.is_floating_point(): + raise TypeError('Batch dtype should be a float tensor. Got {}.'.format(batch.dtype)) + elif target.dtype != torch.int64: + raise TypeError("Target dtype should be torch.int64. Got {}".format(target.dtype)) + + if not self.inplace: + batch = batch.clone() + target = target.clone() + + if target.ndim == 1: + target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=torch.float32) + + if torch.rand(1).item() >= self.p: + return batch, target + + # It's faster to roll the batch by one instead of shuffling it to create image pairs + batch_rolled = batch.roll(1, 0) + target_rolled = target.roll(1) + + # Implemented as on cutmix paper, page 12 (with minor corrections on typos). 
+ lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]) + W, H = F.get_image_size(batch) + + r_x = torch.randint(W, (1,)) + r_y = torch.randint(H, (1,)) + + r = 0.5 * math.sqrt(1.0 - lambda_param) + r_w_half = int(r * W) + r_h_half = int(r * H) + + x1 = int(torch.clamp(r_x - r_w_half, min=0)) + y1 = int(torch.clamp(r_y - r_h_half, min=0)) + x2 = int(torch.clamp(r_x + r_w_half, max=W)) + y2 = int(torch.clamp(r_y + r_h_half, max=H)) + + batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2] + lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H)) + + target_rolled.mul_(1.0 - lambda_param) + target.mul_(lambda_param).add_(target_rolled) + + return batch, target + + def __repr__(self) -> str: + s = self.__class__.__name__ + '(' + s += 'num_classes={num_classes}' + s += ', p={p}' + s += ', alpha={alpha}' + s += ', inplace={inplace}' + s += ')' + return s.format(**self.__dict__) diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index e8a9a920433..aaf7880f124 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -1,10 +1,6 @@ import os import torch -from torch._utils_internal import get_file_path_2 -from torch.utils.data import TensorDataset, DataLoader -from torch.utils.data.dataloader import default_collate from torchvision import transforms as T -from torchvision.io import read_image from torchvision.transforms import functional as F from torchvision.transforms import InterpolationMode @@ -719,78 +715,3 @@ def test_gaussian_blur(device, meth_kwargs): T.GaussianBlur, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, agg_method="max", tol=tol ) - - -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('tranform', [T.RandomMixup, T.RandomCutmix]) -@pytest.mark.parametrize('p', [0.0, 1.0]) -@pytest.mark.parametrize('inplace', [True, False]) -def test_random_mixupcutmix(device, tranform, p, inplace): - batch_size = 32 - num_classes = 10 - batch = torch.rand(batch_size, 3, 44, 56, device=device) - targets = torch.randint(num_classes, (batch_size, ), device=device, dtype=torch.int64) - - fn = tranform(num_classes, p=p, inplace=inplace) - scripted_fn = torch.jit.script(fn) - - seed = torch.seed() - output = fn(batch.clone(), targets.clone()) - - torch.manual_seed(seed) - output_scripted = scripted_fn(batch.clone(), targets.clone()) - assert_equal(output[0], output_scripted[0]) - assert_equal(output[1], output_scripted[1]) - - fn.__repr__() - - -@pytest.mark.parametrize('tranform', [T.RandomMixup, T.RandomCutmix]) -def test_random_mixupcutmix_with_invalid_data(tranform): - with pytest.raises(AssertionError, match="Please provide a valid positive value for the num_classes."): - tranform(0) - with pytest.raises(AssertionError, match="Alpha param can't be zero."): - tranform(10, alpha=0.0) - - t = tranform(10) - with pytest.raises(ValueError, match="Batch ndim should be 4."): - t(torch.rand(3, 60, 60), torch.randint(10, (1, ))) - with pytest.raises(ValueError, match="Target ndim should be 1."): - t(torch.rand(32, 3, 60, 60), torch.randint(10, (32, 1))) - with pytest.raises(TypeError, match="Batch dtype should be a float tensor."): - t(torch.randint(256, (32, 3, 60, 60), dtype=torch.uint8), torch.randint(10, (32, ))) - with pytest.raises(TypeError, match="Target dtype should be torch.int64."): - t(torch.rand(32, 3, 60, 60), torch.randint(10, (32, ), dtype=torch.int32)) - - -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('transform, 
expected', [ - (T.RandomMixup, [60.77401351928711, 0.5151033997535706]), - (T.RandomCutmix, [70.13909912109375, 0.525851309299469]) -]) -def test_random_mixupcutmix_with_real_data(device, transform, expected): - torch.manual_seed(12) - - # Build dummy dataset - images = [] - for test_file in [("encode_jpeg", "grace_hopper_517x606.jpg"), ("fakedata", "logos", "rgb_pytorch.png")]: - fullpath = (os.path.dirname(os.path.abspath(__file__)), 'assets') + test_file - img = read_image(get_file_path_2(*fullpath)) - images.append(F.resize(img, [224, 224])) - dataset = TensorDataset(torch.stack(images).to(device=device, dtype=torch.float32), - torch.tensor([0, 1], device=device)) - - # Use mixup in the collate - trans = transform(2) - dataloader = DataLoader(dataset, batch_size=2, collate_fn=lambda batch: trans(*default_collate(batch))) - - # Test against known statistics about the produced images - stats = [] - for _ in range(25): - for b, t in dataloader: - stats.append(torch.stack([b.std(), t.std()])) - - torch.testing.assert_close( - torch.stack(stats).mean(dim=0), - torch.tensor(expected, device=device) - ) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index a2439f7000e..8da0d016f4d 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -22,8 +22,7 @@ "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize", - "RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize", "RandomMixup", - "RandomCutmix"] + "RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize"] class Compose: @@ -1965,173 +1964,3 @@ def forward(self, img): def __repr__(self): return self.__class__.__name__ + '(p={})'.format(self.p) - - -# TODO: move this to references before merging and delete the tests -class RandomMixup(torch.nn.Module): - """Randomly apply Mixup to the provided batch and targets. - The class implements the data augmentations as described in the paper - `"mixup: Beyond Empirical Risk Minimization" `_. - - Args: - num_classes (int): number of classes used for one-hot encoding. - p (float): probability of the batch being transformed. Default value is 0.5. - alpha (float): hyperparameter of the Beta distribution used for mixup. - Default value is 1.0. - inplace (bool): boolean to make this transform inplace. Default set to False. - """ - - def __init__(self, num_classes: int, - p: float = 0.5, alpha: float = 1.0, - inplace: bool = False) -> None: - super().__init__() - assert num_classes > 0, "Please provide a valid positive value for the num_classes." - assert alpha > 0, "Alpha param can't be zero." - - self.num_classes = num_classes - self.p = p - self.alpha = alpha - self.inplace = inplace - - def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: - """ - Args: - batch (Tensor): Float tensor of size (B, C, H, W) - target (Tensor): Integer tensor of size (B, ) - - Returns: - Tensor: Randomly transformed batch. - """ - if batch.ndim != 4: - raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim)) - elif target.ndim != 1: - raise ValueError("Target ndim should be 1. Got {}".format(target.ndim)) - elif not batch.is_floating_point(): - raise TypeError('Batch dtype should be a float tensor. 
Got {}.'.format(batch.dtype)) - elif target.dtype != torch.int64: - raise TypeError("Target dtype should be torch.int64. Got {}".format(target.dtype)) - - if not self.inplace: - batch = batch.clone() - target = target.clone() - - if target.ndim == 1: - target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=torch.float32) - - if torch.rand(1).item() >= self.p: - return batch, target - - # It's faster to roll the batch by one instead of shuffling it to create image pairs - batch_rolled = batch.roll(1, 0) - target_rolled = target.roll(1) - - # Implemented as on mixup paper, page 3. - lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]) - batch_rolled.mul_(1.0 - lambda_param) - batch.mul_(lambda_param).add_(batch_rolled) - - target_rolled.mul_(1.0 - lambda_param) - target.mul_(lambda_param).add_(target_rolled) - - return batch, target - - def __repr__(self) -> str: - s = self.__class__.__name__ + '(' - s += 'num_classes={num_classes}' - s += ', p={p}' - s += ', alpha={alpha}' - s += ', inplace={inplace}' - s += ')' - return s.format(**self.__dict__) - - -class RandomCutmix(torch.nn.Module): - """Randomly apply Cutmix to the provided batch and targets. - The class implements the data augmentations as described in the paper - `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" - `_. - - Args: - num_classes (int): number of classes used for one-hot encoding. - p (float): probability of the batch being transformed. Default value is 0.5. - alpha (float): hyperparameter of the Beta distribution used for cutmix. - Default value is 1.0. - inplace (bool): boolean to make this transform inplace. Default set to False. - """ - - def __init__(self, num_classes: int, - p: float = 0.5, alpha: float = 1.0, - inplace: bool = False) -> None: - super().__init__() - assert num_classes > 0, "Please provide a valid positive value for the num_classes." - assert alpha > 0, "Alpha param can't be zero." - - self.num_classes = num_classes - self.p = p - self.alpha = alpha - self.inplace = inplace - - def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: - """ - Args: - batch (Tensor): Float tensor of size (B, C, H, W) - target (Tensor): Integer tensor of size (B, ) - - Returns: - Tensor: Randomly transformed batch. - """ - if batch.ndim != 4: - raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim)) - elif target.ndim != 1: - raise ValueError("Target ndim should be 1. Got {}".format(target.ndim)) - elif not batch.is_floating_point(): - raise TypeError('Batch dtype should be a float tensor. Got {}.'.format(batch.dtype)) - elif target.dtype != torch.int64: - raise TypeError("Target dtype should be torch.int64. Got {}".format(target.dtype)) - - if not self.inplace: - batch = batch.clone() - target = target.clone() - - if target.ndim == 1: - target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=torch.float32) - - if torch.rand(1).item() >= self.p: - return batch, target - - # It's faster to roll the batch by one instead of shuffling it to create image pairs - batch_rolled = batch.roll(1, 0) - target_rolled = target.roll(1) - - # Implemented as on cutmix paper, page 12 (with minor corrections on typos). 
- lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]) - W, H = F.get_image_size(batch) - - r_x = torch.randint(W, (1,)) - r_y = torch.randint(H, (1,)) - - r = 0.5 * math.sqrt(1.0 - lambda_param) - r_w_half = int(r * W) - r_h_half = int(r * H) - - x1 = int(torch.clamp(r_x - r_w_half, min=0)) - y1 = int(torch.clamp(r_y - r_h_half, min=0)) - x2 = int(torch.clamp(r_x + r_w_half, max=W)) - y2 = int(torch.clamp(r_y + r_h_half, max=H)) - - batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2] - lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H)) - - target_rolled.mul_(1.0 - lambda_param) - target.mul_(lambda_param).add_(target_rolled) - - return batch, target - - def __repr__(self) -> str: - s = self.__class__.__name__ + '(' - s += 'num_classes={num_classes}' - s += ', p={p}' - s += ', alpha={alpha}' - s += ', inplace={inplace}' - s += ')' - return s.format(**self.__dict__) From b548b7afa99354ed1dfa965c312d22fec7f2aef7 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 15 Sep 2021 17:07:59 +0100 Subject: [PATCH 16/16] Final code clean up. --- references/classification/transforms.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/references/classification/transforms.py b/references/classification/transforms.py index b3d6f9474ce..c4d83ce410c 100644 --- a/references/classification/transforms.py +++ b/references/classification/transforms.py @@ -54,14 +54,14 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: target = target.clone() if target.ndim == 1: - target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=torch.float32) + target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype) if torch.rand(1).item() >= self.p: return batch, target # It's faster to roll the batch by one instead of shuffling it to create image pairs batch_rolled = batch.roll(1, 0) - target_rolled = target.roll(1) + target_rolled = target.roll(1, 0) # Implemented as on mixup paper, page 3. lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]) @@ -132,14 +132,14 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: target = target.clone() if target.ndim == 1: - target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=torch.float32) + target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype) if torch.rand(1).item() >= self.p: return batch, target # It's faster to roll the batch by one instead of shuffling it to create image pairs batch_rolled = batch.roll(1, 0) - target_rolled = target.roll(1) + target_rolled = target.roll(1, 0) # Implemented as on cutmix paper, page 12 (with minor corrections on typos). lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
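
Usage notes. As a quick sanity check of the final references/classification/transforms.py, the sketch below runs RandomCutmix on a toy batch. The batch size, image size and alpha are illustrative values only, and the import assumes the snippet runs from references/classification so the new transforms module is on the path:

    import torch

    from transforms import RandomCutmix  # the new references/classification/transforms.py

    num_classes = 10
    batch = torch.rand(4, 3, 64, 64)                              # float images, shape (B, C, H, W)
    target = torch.randint(num_classes, (4,), dtype=torch.int64)  # int64 class indices, shape (B,)

    # p=1.0 makes the transform always fire; alpha=1.0 is just an example value.
    cutmix = RandomCutmix(num_classes, p=1.0, alpha=1.0)
    mixed_batch, soft_target = cutmix(batch, target)

    print(mixed_batch.shape)       # torch.Size([4, 3, 64, 64]) - images with a pasted box from the rolled batch
    print(soft_target.shape)       # torch.Size([4, 10])        - one-hot targets mixed by the box area
    print(soft_target.sum(dim=1))  # every row sums to 1: lam * y_i + (1 - lam) * y_j

Note that the returned targets are soft label vectors rather than class indices, so the downstream loss has to accept probability targets.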
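
A minimal sketch of the batch-level wiring used in references/classification/train.py: the two transforms are combined with RandomChoice and applied inside the DataLoader's collate function. The toy dataset, batch size and alpha values below are placeholders for the real classification pipeline and the args.mixup_alpha / args.cutmix_alpha flags:

    import torch
    import torchvision
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.dataloader import default_collate

    from transforms import RandomMixup, RandomCutmix

    num_classes = 10
    mixup_transforms = [
        RandomMixup(num_classes, p=1.0, alpha=0.2),   # stand-in for args.mixup_alpha
        RandomCutmix(num_classes, p=1.0, alpha=1.0),  # stand-in for args.cutmix_alpha
    ]
    mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)

    # Toy dataset: float images and int64 labels, as the transforms require.
    dataset = TensorDataset(torch.rand(8, 3, 224, 224),
                            torch.randint(num_classes, (8,), dtype=torch.int64))

    # Mixing happens once per batch, inside the collate function.
    loader = DataLoader(dataset, batch_size=4,
                        collate_fn=lambda batch: mixupcutmix(*default_collate(batch)))

    for images, targets in loader:
        print(images.shape, targets.shape)  # (4, 3, 224, 224) and (4, 10)

Doing the mixing in the collate function keeps the per-sample transforms untouched and lets the augmentation see a whole batch at once, which is what both papers require.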