Commit 3f19902

Separate Mixup from Cutmix.
1 parent 0e128f3 commit 3f19902

4 files changed: +143 −73 lines changed

references/classification/train.py

Lines changed: 8 additions & 3 deletions

@@ -168,9 +168,14 @@ def main(args):
     dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
 
     collate_fn = None
-    if args.mixup_alpha > 0.0 or args.cutmix_alpha > 0.0:
-        mixupcutmix = torchvision.transforms.RandomMixupCutmix(len(dataset.classes), mixup_alpha=args.mixup_alpha,
-                                                               cutmix_alpha=args.cutmix_alpha)
+    num_classes = len(dataset.classes)
+    mixup_transforms = []
+    if args.mixup_alpha > 0.0:
+        mixup_transforms.append(torchvision.transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
+    if args.cutmix_alpha > 0.0:
+        mixup_transforms.append(torchvision.transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
+    if mixup_transforms:
+        mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms, p=[0.5, 0.5])
         collate_fn = lambda batch: mixupcutmix(*default_collate(batch))  # noqa: E731
     data_loader = torch.utils.data.DataLoader(
         dataset, batch_size=args.batch_size,
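
For reference, the new wiring can be exercised end to end outside the training script. A minimal sketch, assuming the RandomMixup/RandomCutmix classes as added in this commit; the dummy dataset and the alpha values are illustrative only:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataloader import default_collate
import torchvision

num_classes = 10
# Hypothetical stand-in for the dataset returned by load_data().
dataset = TensorDataset(torch.rand(8, 3, 224, 224), torch.randint(num_classes, (8,)))

# Same logic as the diff: one transform per enabled alpha, then a 50/50 pick per batch.
mixup_transforms = [
    torchvision.transforms.RandomMixup(num_classes, p=1.0, alpha=0.2),
    torchvision.transforms.RandomCutmix(num_classes, p=1.0, alpha=1.0),
]
mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms, p=[0.5, 0.5])

data_loader = DataLoader(dataset, batch_size=4,
                         collate_fn=lambda batch: mixupcutmix(*default_collate(batch)))
batch, target = next(iter(data_loader))
print(batch.shape, target.shape)  # torch.Size([4, 3, 224, 224]) torch.Size([4, 10])
```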

test/test_transforms.py

Lines changed: 2 additions & 1 deletion

@@ -1311,7 +1311,8 @@ def test_random_choice():
             transforms.Resize(15),
             transforms.Resize(20),
             transforms.CenterCrop(10)
-        ]
+        ],
+        [1 / 3, 1 / 3, 1 / 3]
     )
     img = transforms.ToPILImage()(torch.rand(3, 25, 25))
     num_samples = 250
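
Passing the explicit `[1 / 3, 1 / 3, 1 / 3]` weights keeps the picks uniform, so the frequency assertions over `num_samples` draws are unchanged. A sketch of the pattern the test relies on, with stand-in labels instead of real transforms:

```python
import random
from collections import Counter

choices = ["resize15", "resize20", "centercrop10"]  # stand-ins for the transforms
weights = [1 / 3, 1 / 3, 1 / 3]

# random.choices(..., weights=...) is what the reworked RandomChoice calls internally.
counts = Counter(random.choices(choices, weights=weights)[0] for _ in range(250))
print(counts)  # each label should land near 250 / 3, i.e. about 83
```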

test/test_transforms_tensor.py

Lines changed: 20 additions & 20 deletions

@@ -2,6 +2,7 @@
 import torch
 from torch._utils_internal import get_file_path_2
 from torch.utils.data import TensorDataset, DataLoader
+from torch.utils.data.dataloader import default_collate
 from torchvision import transforms as T
 from torchvision.io import read_image
 from torchvision.transforms import functional as F
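
The new import matters because `default_collate` turns a list of `(image, target)` samples into stacked batch tensors, which is the shape the two-argument transforms expect. Roughly:

```python
import torch
from torch.utils.data.dataloader import default_collate

samples = [(torch.rand(3, 4, 4), torch.tensor(1)),
           (torch.rand(3, 4, 4), torch.tensor(0))]
images, targets = default_collate(samples)
print(images.shape, targets.shape)  # torch.Size([2, 3, 4, 4]) torch.Size([2])
```
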
@@ -721,21 +722,16 @@ def test_gaussian_blur(device, meth_kwargs):
 
 
 @pytest.mark.parametrize('device', cpu_and_gpu())
-@pytest.mark.parametrize('alphas', [
-    {"mixup_alpha": 1.0, "cutmix_alpha": 1.0, 'cutmix_p': 1.0},
-    {"mixup_alpha": 1.0, "cutmix_alpha": 1.0, 'cutmix_p': 0.0},
-    {"mixup_alpha": 1.0, "cutmix_alpha": 1.0, 'p': 0.0},
-    {"mixup_alpha": 0.0, "cutmix_alpha": 1.0},
-    {"mixup_alpha": 1.0, "cutmix_alpha": 0.0},
-])
+@pytest.mark.parametrize('tranform', [T.RandomMixup, T.RandomCutmix])
+@pytest.mark.parametrize('p', [0.0, 1.0])
 @pytest.mark.parametrize('inplace', [True, False])
-def test_random_mixupcutmix(device, alphas, inplace):
+def test_random_mixupcutmix(device, tranform, p, inplace):
     batch_size = 32
     num_classes = 10
     batch = torch.rand(batch_size, 3, 44, 56, device=device)
     targets = torch.randint(num_classes, (batch_size, ), device=device, dtype=torch.int64)
 
-    fn = T.RandomMixupCutmix(num_classes, inplace=inplace, **alphas)
+    fn = tranform(num_classes, p=p, inplace=inplace)
     scripted_fn = torch.jit.script(fn)
 
     seed = torch.seed()

@@ -749,13 +745,14 @@ def test_random_mixupcutmix(device, alphas, inplace):
     fn.__repr__()
 
 
-def test_random_mixupcutmix_with_invalid_data():
+@pytest.mark.parametrize('tranform', [T.RandomMixup, T.RandomCutmix])
+def test_random_mixupcutmix_with_invalid_data(tranform):
     with pytest.raises(AssertionError, match="Please provide a valid positive value for the num_classes."):
-        T.RandomMixupCutmix(0)
-    with pytest.raises(AssertionError, match="Both alpha params can't be zero."):
-        T.RandomMixupCutmix(10, mixup_alpha=0.0, cutmix_alpha=0.0)
+        tranform(0)
+    with pytest.raises(AssertionError, match="Alpha param can't be zero."):
+        tranform(10, alpha=0.0)
 
-    t = T.RandomMixupCutmix(10)
+    t = tranform(10)
     with pytest.raises(ValueError, match="Batch ndim should be 4."):
         t(torch.rand(3, 60, 60), torch.randint(10, (1, )))
     with pytest.raises(ValueError, match="Target ndim should be 1."):

@@ -765,7 +762,11 @@ def test_random_mixupcutmix_with_invalid_data():
 
 
 @pytest.mark.parametrize('device', cpu_and_gpu())
-def test_random_mixupcutmix_with_real_data(device):
+@pytest.mark.parametrize('transform, expected', [
+    (T.RandomMixup, [60.77401351928711, 0.5151033997535706]),
+    (T.RandomCutmix, [70.13909912109375, 0.525851309299469])
+])
+def test_random_mixupcutmix_with_real_data(device, transform, expected):
     torch.manual_seed(12)
 
     # Build dummy dataset

@@ -778,17 +779,16 @@ def test_random_mixupcutmix_with_real_data(device):
                             torch.tensor([0, 1], device=device))
 
     # Use mixup in the collate
-    mixup = T.RandomMixupCutmix(2, cutmix_alpha=1.0, mixup_alpha=1.0)
-    dataloader = DataLoader(dataset, batch_size=2,
-                            collate_fn=lambda batch: mixup(*(torch.stack(x) for x in zip(*batch))))
+    trans = transform(2)
+    dataloader = DataLoader(dataset, batch_size=2, collate_fn=lambda batch: trans(*default_collate(batch)))
 
     # Test against known statistics about the produced images
     stats = []
     for _ in range(25):
         for b, t in dataloader:
-            stats.append(torch.stack([b.mean(), b.std(), t.std()]))
+            stats.append(torch.stack([b.std(), t.std()]))
 
     torch.testing.assert_close(
         torch.stack(stats).mean(dim=0),
-        torch.tensor([46.9443473815918, 64.79092407226562, 0.459820032119751])
+        torch.tensor(expected)
     )
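
The `expected` values are regression statistics captured once from a seeded run. A self-contained sketch of how such numbers are produced; the dataset below is a hypothetical stand-in, so its printed values will differ from the test's:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataloader import default_collate
from torchvision import transforms as T

torch.manual_seed(12)
dataset = TensorDataset(torch.rand(2, 3, 44, 56) * 255, torch.tensor([0, 1]))
trans = T.RandomMixup(2)  # or T.RandomCutmix(2)
dataloader = DataLoader(dataset, batch_size=2,
                        collate_fn=lambda batch: trans(*default_collate(batch)))

stats = []
for _ in range(25):
    for b, t in dataloader:
        stats.append(torch.stack([b.std(), t.std()]))
print(torch.stack(stats).mean(dim=0))  # record once, then assert against it
```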

torchvision/transforms/transforms.py

Lines changed: 113 additions & 49 deletions

@@ -22,7 +22,8 @@
     "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
     "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
     "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize",
-    "RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize", 'RandomMixupCutmix']
+    "RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize", 'RandomMixup',
+    "RandomCutmix"]
 
 
 class Compose:

@@ -515,9 +516,20 @@ def __call__(self, img):
 class RandomChoice(RandomTransforms):
     """Apply single transformation randomly picked from a list. This transform does not support torchscript.
     """
-    def __call__(self, img):
-        t = random.choice(self.transforms)
-        return t(img)
+    def __init__(self, transforms, p=None):
+        super().__init__(transforms)
+        if p is not None and not isinstance(p, Sequence):
+            raise TypeError("Argument transforms should be a sequence")
+        self.p = p
+
+    def __call__(self, *args):
+        t = random.choices(self.transforms, weights=self.p)[0]
+        return t(*args)
+
+    def __repr__(self):
+        format_string = super().__repr__()
+        format_string += '(p={0})'.format(self.p)
+        return format_string
 
 
 class RandomCrop(torch.nn.Module):
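
With this change RandomChoice gains optional sampling weights and forwards any number of positional arguments to the picked transform, which is what lets it dispatch `(batch, target)` pairs. A quick sketch using the classes added further down in this commit:

```python
import torch
from torchvision import transforms as T

picker = T.RandomChoice(
    [T.RandomMixup(10, p=1.0), T.RandomCutmix(10, p=1.0)],
    p=[0.5, 0.5],  # weights for which transform is drawn, not an "apply" probability
)
batch = torch.rand(4, 3, 32, 32)
target = torch.randint(10, (4,))
batch, target = picker(batch, target)  # both arguments reach the chosen transform
```

Note that each drawn transform still applies its own `p` on top of the draw.
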
@@ -1956,38 +1968,103 @@ def __repr__(self):
 
 
 # TODO: move this to references before merging and delete the tests
-class RandomMixupCutmix(torch.nn.Module):
-    """Randomly apply Mixup or Cutmix to the provided batch and targets.
-    The class implements the data augmentations as described in the papers
-    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_ and
+class RandomMixup(torch.nn.Module):
+    """Randomly apply Mixup to the provided batch and targets.
+    The class implements the data augmentations as described in the paper
+    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
+
+    Args:
+        num_classes (int): number of classes used for one-hot encoding.
+        p (float): probability of the batch being transformed. Default value is 0.5.
+        alpha (float): hyperparameter of the Beta distribution used for mixup.
+            Default value is 1.0.
+        inplace (bool): boolean to make this transform inplace. Default set to False.
+    """
+
+    def __init__(self, num_classes: int,
+                 p: float = 0.5, alpha: float = 1.0,
+                 inplace: bool = False) -> None:
+        super().__init__()
+        assert num_classes > 0, "Please provide a valid positive value for the num_classes."
+        assert alpha > 0, "Alpha param can't be zero."
+
+        self.num_classes = num_classes
+        self.p = p
+        self.alpha = alpha
+        self.inplace = inplace
+
+    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
+        """
+        Args:
+            batch (Tensor): Float tensor of size (B, C, H, W)
+            target (Tensor): Integer tensor of size (B, )
+
+        Returns:
+            Tensor: Randomly transformed batch.
+        """
+        if batch.ndim != 4:
+            raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim))
+        elif target.ndim != 1:
+            raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
+        elif target.dtype != torch.int64:
+            raise ValueError("Target dtype should be torch.int64. Got {}".format(target.dtype))
+
+        if not self.inplace:
+            batch = batch.clone()
+            # target = target.clone()
+
+        target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=torch.float32)
+        if torch.rand(1).item() >= self.p:
+            return batch, target
+
+        # It's faster to roll the batch by one instead of shuffling it to create image pairs
+        batch_rolled = batch.roll(1, 0)
+        target_rolled = target.roll(1)
+
+        # Implemented as on mixup paper, page 3.
+        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
+        batch_rolled.mul_(1.0 - lambda_param)
+        batch.mul_(lambda_param).add_(batch_rolled)
+
+        target_rolled.mul_(1.0 - lambda_param)
+        target.mul_(lambda_param).add_(target_rolled)
+
+        return batch, target
+
+    def __repr__(self) -> str:
+        s = self.__class__.__name__ + '('
+        s += 'num_classes={num_classes}'
+        s += ', p={p}'
+        s += ', alpha={alpha}'
+        s += ', inplace={inplace}'
+        s += ')'
+        return s.format(**self.__dict__)
+
+
+class RandomCutmix(torch.nn.Module):
+    """Randomly apply Cutmix to the provided batch and targets.
+    The class implements the data augmentations as described in the paper
     `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
     <https://arxiv.org/abs/1905.04899>`_.
 
     Args:
         num_classes (int): number of classes used for one-hot encoding.
-        p (float): probability of the batch being transformed. Default value is 1.0.
-        mixup_alpha (float): hyperparameter of the Beta distribution used for mixup.
-            Set to 0.0 to turn off. Default value is 1.0.
-        cutmix_p (float): probability of using cutmix instead of mixup when both are on.
-            Default value is 0.5.
-        cutmix_alpha (float): hyperparameter of the Beta distribution used for cutmix.
-            Set to 0.0 to turn off. Default value is 0.0.
+        p (float): probability of the batch being transformed. Default value is 0.5.
+        alpha (float): hyperparameter of the Beta distribution used for cutmix.
+            Default value is 1.0.
         inplace (bool): boolean to make this transform inplace. Default set to False.
     """
 
     def __init__(self, num_classes: int,
-                 p: float = 1.0, mixup_alpha: float = 1.0,
-                 cutmix_p: float = 0.5, cutmix_alpha: float = 0.0,
+                 p: float = 0.5, alpha: float = 1.0,
                  inplace: bool = False) -> None:
         super().__init__()
         assert num_classes > 0, "Please provide a valid positive value for the num_classes."
-        assert mixup_alpha > 0 or cutmix_alpha > 0, "Both alpha params can't be zero."
+        assert alpha > 0, "Alpha param can't be zero."
 
         self.num_classes = num_classes
         self.p = p
-        self.mixup_alpha = mixup_alpha
-        self.cutmix_p = cutmix_p
-        self.cutmix_alpha = cutmix_alpha
+        self.alpha = alpha
         self.inplace = inplace
 
     def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
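
The mixing in RandomMixup above is the convex combination from the mixup paper: draw lambda ~ Beta(alpha, alpha), then blend each image (and its one-hot target) with its rolled neighbour. The same arithmetic in plain tensor ops, rolling along the batch dimension to form the pairs:

```python
import torch

alpha = 1.0
# The commit samples Beta(alpha, alpha) via torch._sample_dirichlet; the public
# torch.distributions API draws from the same distribution.
lam = torch.distributions.Beta(alpha, alpha).sample().item()

batch = torch.rand(4, 3, 8, 8)
target = torch.nn.functional.one_hot(torch.arange(4), num_classes=4).float()

mixed_batch = lam * batch + (1.0 - lam) * batch.roll(1, 0)
mixed_target = lam * target + (1.0 - lam) * target.roll(1, 0)
print(mixed_target.sum(dim=1))  # each row still sums to 1.0
```
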
@@ -2018,35 +2095,24 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
         batch_rolled = batch.roll(1, 0)
         target_rolled = target.roll(1)
 
-        if self.mixup_alpha <= 0.0:
-            use_mixup = False
-        else:
-            use_mixup = self.cutmix_alpha <= 0.0 or torch.rand(1).item() >= self.cutmix_p
-
-        if use_mixup:
-            # Implemented as on mixup paper, page 3.
-            lambda_param = float(torch._sample_dirichlet(torch.tensor([self.mixup_alpha, self.mixup_alpha]))[0])
-            batch_rolled.mul_(1.0 - lambda_param)
-            batch.mul_(lambda_param).add_(batch_rolled)
-        else:
-            # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
-            lambda_param = float(torch._sample_dirichlet(torch.tensor([self.cutmix_alpha, self.cutmix_alpha]))[0])
-            W, H = F.get_image_size(batch)
+        # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
+        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
+        W, H = F.get_image_size(batch)
 
-            r_x = torch.randint(W, (1,))
-            r_y = torch.randint(H, (1,))
+        r_x = torch.randint(W, (1,))
+        r_y = torch.randint(H, (1,))
 
-            r = 0.5 * math.sqrt(1.0 - lambda_param)
-            r_w_half = int(r * W)
-            r_h_half = int(r * H)
+        r = 0.5 * math.sqrt(1.0 - lambda_param)
+        r_w_half = int(r * W)
+        r_h_half = int(r * H)
 
-            x1 = int(torch.clamp(r_x - r_w_half, min=0))
-            y1 = int(torch.clamp(r_y - r_h_half, min=0))
-            x2 = int(torch.clamp(r_x + r_w_half, max=W))
-            y2 = int(torch.clamp(r_y + r_h_half, max=H))
+        x1 = int(torch.clamp(r_x - r_w_half, min=0))
+        y1 = int(torch.clamp(r_y - r_h_half, min=0))
+        x2 = int(torch.clamp(r_x + r_w_half, max=W))
+        y2 = int(torch.clamp(r_y + r_h_half, max=H))
 
-            batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
-            lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
+        batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
+        lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
 
         target_rolled.mul_(1.0 - lambda_param)
         target.mul_(lambda_param).add_(target_rolled)
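
The cutmix path pastes a random box from the rolled batch and then recomputes lambda as the fraction of pixels kept, so the label mix matches the pixel mix exactly. A worked check of the box arithmetic under assumed values:

```python
import math

W, H, lam = 56, 44, 0.3                  # assumed image size and initial Beta sample
r = 0.5 * math.sqrt(1.0 - lam)           # half-size ratio of the cut box
r_w_half, r_h_half = int(r * W), int(r * H)
r_x, r_y = 20, 10                        # assumed box centre (randint(W), randint(H) in the code)

x1, y1 = max(r_x - r_w_half, 0), max(r_y - r_h_half, 0)
x2, y2 = min(r_x + r_w_half, W), min(r_y + r_h_half, H)

# Fraction of the original image that survives; clipping at the borders is why
# lambda must be recomputed from the actual box area.
lam_adjusted = 1.0 - (x2 - x1) * (y2 - y1) / (W * H)
print((x1, y1, x2, y2), lam_adjusted)
```
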
@@ -2057,9 +2123,7 @@ def __repr__(self) -> str:
         s = self.__class__.__name__ + '('
         s += 'num_classes={num_classes}'
         s += ', p={p}'
-        s += ', mixup_alpha={mixup_alpha}'
-        s += ', cutmix_p={cutmix_p}'
-        s += ', cutmix_alpha={cutmix_alpha}'
+        s += ', alpha={alpha}'
         s += ', inplace={inplace}'
         s += ')'
         return s.format(**self.__dict__)
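
As the updated tests exercise, both classes stay TorchScript-compatible: they are plain nn.Modules with typed forward signatures. A quick check, assuming the classes as they exist at this commit:

```python
import torch
from torchvision import transforms as T

fn = T.RandomCutmix(num_classes=10, p=1.0, alpha=1.0)
scripted_fn = torch.jit.script(fn)  # mirrors torch.jit.script(fn) in the tests

batch = torch.rand(8, 3, 44, 56)
target = torch.randint(10, (8,), dtype=torch.int64)
out_batch, out_target = scripted_fn(batch, target)
print(out_batch.shape, out_target.shape)  # (8, 3, 44, 56) and (8, 10)
```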
