Commit c8e3b2a

Adding Mixup and Cutmix (#4379)
* Add RandomMixupCutmix.
* Add test with real data.
* Use dataloader and collate in the test.
* Making RandomMixupCutmix JIT scriptable.
* Move out label_smoothing and try roll instead of flip.
* Adding mixup/cutmix in references script.
* Handle one-hot encoded target in accuracy.
* Add support of devices on tests.
* Separate Mixup from Cutmix.
* Add check for floats.
* Adding device on expect value.
* Remove hardcoded weights.
* One-hot only when necessary.
* Fix linter.
* Moving mixup and cutmix to references.
* Final code clean up.
1 parent b096271 commit c8e3b2a

File tree

5 files changed: +210 additions, -6 deletions

  references/classification/train.py
  references/classification/transforms.py
  references/classification/utils.py
  test/test_transforms.py
  torchvision/transforms/transforms.py

references/classification/train.py

Lines changed: 17 additions & 2 deletions
@@ -4,11 +4,13 @@

 import torch
 import torch.utils.data
+from torch.utils.data.dataloader import default_collate
 from torch import nn
 import torchvision
 from torchvision.transforms.functional import InterpolationMode

 import presets
+import transforms
 import utils

 try:
@@ -164,10 +166,21 @@ def main(args):
     train_dir = os.path.join(args.data_path, 'train')
     val_dir = os.path.join(args.data_path, 'val')
     dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
+
+    collate_fn = None
+    num_classes = len(dataset.classes)
+    mixup_transforms = []
+    if args.mixup_alpha > 0.0:
+        mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
+    if args.cutmix_alpha > 0.0:
+        mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
+    if mixup_transforms:
+        mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
+        collate_fn = lambda batch: mixupcutmix(*default_collate(batch))  # noqa: E731
     data_loader = torch.utils.data.DataLoader(
         dataset, batch_size=args.batch_size,
-        sampler=train_sampler, num_workers=args.workers, pin_memory=True)
-
+        sampler=train_sampler, num_workers=args.workers, pin_memory=True,
+        collate_fn=collate_fn)
     data_loader_test = torch.utils.data.DataLoader(
         dataset_test, batch_size=args.batch_size,
         sampler=test_sampler, num_workers=args.workers, pin_memory=True)
@@ -272,6 +285,8 @@ def get_args_parser(add_help=True):
     parser.add_argument('--label-smoothing', default=0.0, type=float,
                         help='label smoothing (default: 0.0)',
                         dest='label_smoothing')
+    parser.add_argument('--mixup-alpha', default=0.0, type=float, help='mixup alpha (default: 0.0)')
+    parser.add_argument('--cutmix-alpha', default=0.0, type=float, help='cutmix alpha (default: 0.0)')
     parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
     parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
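
Worth noting: the collate_fn hook is what lets mixup/cutmix run per batch, since both transforms blend pairs of samples and cannot operate on a single image. Below is a minimal sketch of the same wiring on synthetic data; it assumes a torchvision build containing this commit and that the new references/classification/transforms.py module is on the import path.

import torch
import torchvision
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataloader import default_collate
from transforms import RandomMixup, RandomCutmix  # the module added by this commit

# Synthetic stand-in for an image classification dataset: 32 RGB images, 10 classes.
dataset = TensorDataset(torch.rand(32, 3, 224, 224), torch.randint(0, 10, (32,)))

# One of the two transforms is picked per batch, mirroring what train.py does
# when --mixup-alpha and --cutmix-alpha are both positive.
mixupcutmix = torchvision.transforms.RandomChoice([
    RandomMixup(num_classes=10, p=1.0, alpha=0.2),
    RandomCutmix(num_classes=10, p=1.0, alpha=1.0),
])
loader = DataLoader(dataset, batch_size=8,
                    collate_fn=lambda batch: mixupcutmix(*default_collate(batch)))

images, targets = next(iter(loader))
print(images.shape, targets.shape)  # torch.Size([8, 3, 224, 224]) torch.Size([8, 10])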
references/classification/transforms.py

Lines changed: 175 additions & 0 deletions

@@ -0,0 +1,175 @@
+import math
+import torch
+
+from typing import Tuple
+from torch import Tensor
+from torchvision.transforms import functional as F
+
+
+class RandomMixup(torch.nn.Module):
+    """Randomly apply Mixup to the provided batch and targets.
+    The class implements the data augmentations as described in the paper
+    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
+
+    Args:
+        num_classes (int): number of classes used for one-hot encoding.
+        p (float): probability of the batch being transformed. Default value is 0.5.
+        alpha (float): hyperparameter of the Beta distribution used for mixup.
+            Default value is 1.0.
+        inplace (bool): boolean to make this transform inplace. Default set to False.
+    """
+
+    def __init__(self, num_classes: int,
+                 p: float = 0.5, alpha: float = 1.0,
+                 inplace: bool = False) -> None:
+        super().__init__()
+        assert num_classes > 0, "Please provide a valid positive value for the num_classes."
+        assert alpha > 0, "Alpha param can't be zero."
+
+        self.num_classes = num_classes
+        self.p = p
+        self.alpha = alpha
+        self.inplace = inplace
+
+    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
+        """
+        Args:
+            batch (Tensor): Float tensor of size (B, C, H, W)
+            target (Tensor): Integer tensor of size (B, )
+
+        Returns:
+            Tuple[Tensor, Tensor]: Randomly transformed batch and target.
+        """
+        if batch.ndim != 4:
+            raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim))
+        elif target.ndim != 1:
+            raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
+        elif not batch.is_floating_point():
+            raise TypeError('Batch dtype should be a float tensor. Got {}.'.format(batch.dtype))
+        elif target.dtype != torch.int64:
+            raise TypeError("Target dtype should be torch.int64. Got {}".format(target.dtype))
+
+        if not self.inplace:
+            batch = batch.clone()
+            target = target.clone()
+
+        if target.ndim == 1:
+            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
+
+        if torch.rand(1).item() >= self.p:
+            return batch, target
+
+        # It's faster to roll the batch by one instead of shuffling it to create image pairs
+        batch_rolled = batch.roll(1, 0)
+        target_rolled = target.roll(1, 0)
+
+        # Implemented as described in the mixup paper, page 3.
+        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
+        batch_rolled.mul_(1.0 - lambda_param)
+        batch.mul_(lambda_param).add_(batch_rolled)
+
+        target_rolled.mul_(1.0 - lambda_param)
+        target.mul_(lambda_param).add_(target_rolled)
+
+        return batch, target
+
+    def __repr__(self) -> str:
+        s = self.__class__.__name__ + '('
+        s += 'num_classes={num_classes}'
+        s += ', p={p}'
+        s += ', alpha={alpha}'
+        s += ', inplace={inplace}'
+        s += ')'
+        return s.format(**self.__dict__)
+
+
+class RandomCutmix(torch.nn.Module):
+    """Randomly apply Cutmix to the provided batch and targets.
+    The class implements the data augmentations as described in the paper
+    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
+    <https://arxiv.org/abs/1905.04899>`_.
+
+    Args:
+        num_classes (int): number of classes used for one-hot encoding.
+        p (float): probability of the batch being transformed. Default value is 0.5.
+        alpha (float): hyperparameter of the Beta distribution used for cutmix.
+            Default value is 1.0.
+        inplace (bool): boolean to make this transform inplace. Default set to False.
+    """
+
+    def __init__(self, num_classes: int,
+                 p: float = 0.5, alpha: float = 1.0,
+                 inplace: bool = False) -> None:
+        super().__init__()
+        assert num_classes > 0, "Please provide a valid positive value for the num_classes."
+        assert alpha > 0, "Alpha param can't be zero."
+
+        self.num_classes = num_classes
+        self.p = p
+        self.alpha = alpha
+        self.inplace = inplace
+
+    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
+        """
+        Args:
+            batch (Tensor): Float tensor of size (B, C, H, W)
+            target (Tensor): Integer tensor of size (B, )
+
+        Returns:
+            Tuple[Tensor, Tensor]: Randomly transformed batch and target.
+        """
+        if batch.ndim != 4:
+            raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim))
+        elif target.ndim != 1:
+            raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
+        elif not batch.is_floating_point():
+            raise TypeError('Batch dtype should be a float tensor. Got {}.'.format(batch.dtype))
+        elif target.dtype != torch.int64:
+            raise TypeError("Target dtype should be torch.int64. Got {}".format(target.dtype))
+
+        if not self.inplace:
+            batch = batch.clone()
+            target = target.clone()
+
+        if target.ndim == 1:
+            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
+
+        if torch.rand(1).item() >= self.p:
+            return batch, target
+
+        # It's faster to roll the batch by one instead of shuffling it to create image pairs
+        batch_rolled = batch.roll(1, 0)
+        target_rolled = target.roll(1, 0)
+
+        # Implemented as described in the cutmix paper, page 12 (with minor corrections on typos).
+        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
+        W, H = F.get_image_size(batch)
+
+        r_x = torch.randint(W, (1,))
+        r_y = torch.randint(H, (1,))
+
+        r = 0.5 * math.sqrt(1.0 - lambda_param)
+        r_w_half = int(r * W)
+        r_h_half = int(r * H)
+
+        x1 = int(torch.clamp(r_x - r_w_half, min=0))
+        y1 = int(torch.clamp(r_y - r_h_half, min=0))
+        x2 = int(torch.clamp(r_x + r_w_half, max=W))
+        y2 = int(torch.clamp(r_y + r_h_half, max=H))
+
+        batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
+        lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
+
+        target_rolled.mul_(1.0 - lambda_param)
+        target.mul_(lambda_param).add_(target_rolled)
+
+        return batch, target
+
+    def __repr__(self) -> str:
+        s = self.__class__.__name__ + '('
+        s += 'num_classes={num_classes}'
+        s += ', p={p}'
+        s += ', alpha={alpha}'
+        s += ', inplace={inplace}'
+        s += ')'
+        return s.format(**self.__dict__)
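
A note on the sampling above: the first coordinate of a two-component Dirichlet(alpha, alpha) is distributed as Beta(alpha, alpha), so the private torch._sample_dirichlet call is effectively drawing the mixing coefficient lambda from Beta(alpha, alpha), as the mixup paper prescribes (mixed_x = lambda * x_i + (1 - lambda) * x_j, and likewise for the one-hot targets). A quick empirical check of that equivalence using only public APIs:

import torch

alpha = 0.2
n = 100_000
beta_draws = torch.distributions.Beta(alpha, alpha).sample((n,))
dirichlet_draws = torch.distributions.Dirichlet(torch.tensor([alpha, alpha])).sample((n,))[:, 0]

# Both should agree closely in mean (0.5 by symmetry) and variance.
print(beta_draws.mean().item(), dirichlet_draws.mean().item())
print(beta_draws.var().item(), dirichlet_draws.var().item())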

references/classification/utils.py

Lines changed: 2 additions & 0 deletions
@@ -189,6 +189,8 @@ def accuracy(output, target, topk=(1,)):
     with torch.no_grad():
         maxk = max(topk)
         batch_size = target.size(0)
+        if target.ndim == 2:
+            target = target.max(dim=1)[1]

         _, pred = output.topk(maxk, 1, True, True)
         pred = pred.t()
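
Since mixup/cutmix replace the integer targets with (B, num_classes) soft-label tensors, accuracy() now recovers a hard label via arg-max whenever it sees a 2-D target. A tiny illustration of just that branch:

import torch

# Soft targets as mixup might produce them: rows sum to 1 but are not one-hot.
target = torch.tensor([[0.7, 0.3, 0.0],
                       [0.1, 0.2, 0.7]])
if target.ndim == 2:
    target = target.max(dim=1)[1]  # index of the dominant class per sample
print(target)  # tensor([0, 2])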

test/test_transforms.py

Lines changed: 2 additions & 1 deletion
@@ -1311,7 +1311,8 @@ def test_random_choice():
             transforms.Resize(15),
             transforms.Resize(20),
             transforms.CenterCrop(10)
-        ]
+        ],
+        [1 / 3, 1 / 3, 1 / 3]
     )
     img = transforms.ToPILImage()(torch.rand(3, 25, 25))
     num_samples = 250

torchvision/transforms/transforms.py

Lines changed: 14 additions & 3 deletions
@@ -515,9 +515,20 @@ def __call__(self, img):
515515
class RandomChoice(RandomTransforms):
516516
"""Apply single transformation randomly picked from a list. This transform does not support torchscript.
517517
"""
518-
def __call__(self, img):
519-
t = random.choice(self.transforms)
520-
return t(img)
518+
def __init__(self, transforms, p=None):
519+
super().__init__(transforms)
520+
if p is not None and not isinstance(p, Sequence):
521+
raise TypeError("Argument transforms should be a sequence")
522+
self.p = p
523+
524+
def __call__(self, *args):
525+
t = random.choices(self.transforms, weights=self.p)[0]
526+
return t(*args)
527+
528+
def __repr__(self):
529+
format_string = super().__repr__()
530+
format_string += '(p={0})'.format(self.p)
531+
return format_string
521532

522533

523534
class RandomCrop(torch.nn.Module):
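
The reworked RandomChoice forwards arbitrary positional arguments, which is what lets it wrap the two-argument (batch, target) mixup/cutmix transforms, and it gains optional selection weights p. A short sketch of the weighted selection, assuming a torchvision build that includes this change:

import torch
from torchvision import transforms

# Pick the horizontal flip ~90% of the time, the vertical flip ~10%.
t = transforms.RandomChoice(
    [transforms.RandomHorizontalFlip(p=1.0), transforms.RandomVerticalFlip(p=1.0)],
    p=[0.9, 0.1],
)
img = transforms.ToPILImage()(torch.rand(3, 25, 25))
out = t(img)  # exactly one of the two flips, chosen with the given weights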
