
Commit eb932b9

Moving mixup and cutmix to references.
1 parent b5bf8fc commit eb932b9

4 files changed: +179 -253 lines

4 files changed

+179
-253
lines changed

references/classification/train.py

Lines changed: 3 additions & 2 deletions
@@ -10,6 +10,7 @@
 from torchvision.transforms.functional import InterpolationMode
 
 import presets
+import transforms
 import utils
 
 try:
@@ -170,9 +171,9 @@ def main(args):
     num_classes = len(dataset.classes)
     mixup_transforms = []
     if args.mixup_alpha > 0.0:
-        mixup_transforms.append(torchvision.transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
+        mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
     if args.cutmix_alpha > 0.0:
-        mixup_transforms.append(torchvision.transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
+        mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
     if mixup_transforms:
         mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
         collate_fn = lambda batch: mixupcutmix(*default_collate(batch))  # noqa: E731
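
Because mixup and cutmix act on whole stacked batches rather than single images, the hunk above keeps them out of the per-sample transform pipeline and applies them through the DataLoader's collate_fn instead. A minimal sketch of that pattern, with a tiny synthetic dataset standing in for the real one train.py builds (the dataset contents and alpha=0.2 are illustrative assumptions, not values from this commit):

import torch
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.dataloader import default_collate

from transforms import RandomMixup  # the new references/classification/transforms.py

num_classes = 10
# Synthetic stand-in for the training dataset.
dataset = TensorDataset(
    torch.rand(8, 3, 32, 32),
    torch.randint(num_classes, (8,), dtype=torch.int64),
)
mixupcutmix = RandomMixup(num_classes, p=1.0, alpha=0.2)

# default_collate stacks the samples into (batch, targets); the batch-level
# transform then mixes the stacked tensors before they reach the training loop.
loader = DataLoader(dataset, batch_size=4,
                    collate_fn=lambda b: mixupcutmix(*default_collate(b)))

for images, targets in loader:
    print(images.shape, targets.shape)  # [4, 3, 32, 32] and [4, 10]: targets are now soft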
references/classification/transforms.py (new file)

Lines changed: 175 additions & 0 deletions

@@ -0,0 +1,175 @@
+import math
+import torch
+
+from typing import Tuple
+from torch import Tensor
+from torchvision.transforms import functional as F
+
+
+class RandomMixup(torch.nn.Module):
+    """Randomly apply Mixup to the provided batch and targets.
+    The class implements the data augmentations as described in the paper
+    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
+
+    Args:
+        num_classes (int): number of classes used for one-hot encoding.
+        p (float): probability of the batch being transformed. Default value is 0.5.
+        alpha (float): hyperparameter of the Beta distribution used for mixup.
+            Default value is 1.0.
+        inplace (bool): boolean to make this transform inplace. Default set to False.
+    """
+
+    def __init__(self, num_classes: int,
+                 p: float = 0.5, alpha: float = 1.0,
+                 inplace: bool = False) -> None:
+        super().__init__()
+        assert num_classes > 0, "Please provide a valid positive value for the num_classes."
+        assert alpha > 0, "Alpha param can't be zero."
+
+        self.num_classes = num_classes
+        self.p = p
+        self.alpha = alpha
+        self.inplace = inplace
+
+    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
+        """
+        Args:
+            batch (Tensor): Float tensor of size (B, C, H, W)
+            target (Tensor): Integer tensor of size (B, )
+
+        Returns:
+            Tuple[Tensor, Tensor]: Randomly transformed batch and target.
+        """
+        if batch.ndim != 4:
+            raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim))
+        elif target.ndim != 1:
+            raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
+        elif not batch.is_floating_point():
+            raise TypeError('Batch dtype should be a float tensor. Got {}.'.format(batch.dtype))
+        elif target.dtype != torch.int64:
+            raise TypeError("Target dtype should be torch.int64. Got {}".format(target.dtype))
+
+        if not self.inplace:
+            batch = batch.clone()
+            target = target.clone()
+
+        if target.ndim == 1:
+            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=torch.float32)
+
+        if torch.rand(1).item() >= self.p:
+            return batch, target
+
+        # It's faster to roll the batch by one instead of shuffling it to create image pairs
+        batch_rolled = batch.roll(1, 0)
+        target_rolled = target.roll(1, 0)
+
+        # Implemented as described in the mixup paper, page 3.
+        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
+        batch_rolled.mul_(1.0 - lambda_param)
+        batch.mul_(lambda_param).add_(batch_rolled)
+
+        target_rolled.mul_(1.0 - lambda_param)
+        target.mul_(lambda_param).add_(target_rolled)
+
+        return batch, target
+
+    def __repr__(self) -> str:
+        s = self.__class__.__name__ + '('
+        s += 'num_classes={num_classes}'
+        s += ', p={p}'
+        s += ', alpha={alpha}'
+        s += ', inplace={inplace}'
+        s += ')'
+        return s.format(**self.__dict__)
+
+
+class RandomCutmix(torch.nn.Module):
+    """Randomly apply Cutmix to the provided batch and targets.
+    The class implements the data augmentations as described in the paper
+    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
+    <https://arxiv.org/abs/1905.04899>`_.
+
+    Args:
+        num_classes (int): number of classes used for one-hot encoding.
+        p (float): probability of the batch being transformed. Default value is 0.5.
+        alpha (float): hyperparameter of the Beta distribution used for cutmix.
+            Default value is 1.0.
+        inplace (bool): boolean to make this transform inplace. Default set to False.
+    """
+
+    def __init__(self, num_classes: int,
+                 p: float = 0.5, alpha: float = 1.0,
+                 inplace: bool = False) -> None:
+        super().__init__()
+        assert num_classes > 0, "Please provide a valid positive value for the num_classes."
+        assert alpha > 0, "Alpha param can't be zero."
+
+        self.num_classes = num_classes
+        self.p = p
+        self.alpha = alpha
+        self.inplace = inplace
+
+    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
+        """
+        Args:
+            batch (Tensor): Float tensor of size (B, C, H, W)
+            target (Tensor): Integer tensor of size (B, )
+
+        Returns:
+            Tuple[Tensor, Tensor]: Randomly transformed batch and target.
+        """
+        if batch.ndim != 4:
+            raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim))
+        elif target.ndim != 1:
+            raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
+        elif not batch.is_floating_point():
+            raise TypeError('Batch dtype should be a float tensor. Got {}.'.format(batch.dtype))
+        elif target.dtype != torch.int64:
+            raise TypeError("Target dtype should be torch.int64. Got {}".format(target.dtype))
+
+        if not self.inplace:
+            batch = batch.clone()
+            target = target.clone()
+
+        if target.ndim == 1:
+            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=torch.float32)
+
+        if torch.rand(1).item() >= self.p:
+            return batch, target
+
+        # It's faster to roll the batch by one instead of shuffling it to create image pairs
+        batch_rolled = batch.roll(1, 0)
+        target_rolled = target.roll(1, 0)
+
+        # Implemented as described in the cutmix paper, page 12 (with minor corrections on typos).
+        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
+        W, H = F.get_image_size(batch)
+
+        r_x = torch.randint(W, (1,))
+        r_y = torch.randint(H, (1,))
+
+        r = 0.5 * math.sqrt(1.0 - lambda_param)
+        r_w_half = int(r * W)
+        r_h_half = int(r * H)
+
+        x1 = int(torch.clamp(r_x - r_w_half, min=0))
+        y1 = int(torch.clamp(r_y - r_h_half, min=0))
+        x2 = int(torch.clamp(r_x + r_w_half, max=W))
+        y2 = int(torch.clamp(r_y + r_h_half, max=H))
+
+        batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
+        lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
+
+        target_rolled.mul_(1.0 - lambda_param)
+        target.mul_(lambda_param).add_(target_rolled)
+
+        return batch, target
+
+    def __repr__(self) -> str:
+        s = self.__class__.__name__ + '('
+        s += 'num_classes={num_classes}'
+        s += ', p={p}'
+        s += ', alpha={alpha}'
+        s += ', inplace={inplace}'
+        s += ')'
+        return s.format(**self.__dict__)
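
For intuition on what RandomMixup.forward produces: drawing from torch._sample_dirichlet over [alpha, alpha] samples the mixing weight lambda from Beta(alpha, alpha) (the two-component Dirichlet has exactly that marginal), and each image and one-hot target becomes a convex combination with its rolled neighbour. A short demo (the shapes and seed are illustrative, and it assumes the new transforms.py above is importable):

import torch
from transforms import RandomMixup

torch.manual_seed(0)
mixup = RandomMixup(num_classes=4, p=1.0, alpha=1.0)
batch = torch.rand(2, 3, 8, 8)           # two dummy images
target = torch.tensor([0, 3])            # their integer labels

mixed_batch, mixed_target = mixup(batch, target)
print(mixed_batch.shape)                 # torch.Size([2, 3, 8, 8])
print(mixed_target.shape)                # torch.Size([2, 4]); labels are now soft
print(mixed_target.sum(dim=1))           # each row still sums to 1 (up to float rounding)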

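RandomCutmix implements the box construction from page 12 of the CutMix paper: the cut is centred at a uniformly drawn point, each half-extent is 0.5 * sqrt(1 - lambda) of the corresponding side (so the un-clipped box covers roughly a (1 - lambda) fraction of the image), and lambda is then recomputed from the clipped box so the target weights match the pixels actually pasted. A numeric walk-through (the sampled lambda and centre point are made-up example values):

import math

lambda_param = 0.75                          # suppose Beta(alpha, alpha) drew 0.75
W, H = 224, 224
r = 0.5 * math.sqrt(1.0 - lambda_param)      # 0.25: half-extent is a quarter of each side
r_w_half, r_h_half = int(r * W), int(r * H)  # 56 and 56, i.e. a 112 x 112 box
r_x, r_y = 112, 112                          # suppose the uniform draw lands at the centre
x1, y1 = max(0, r_x - r_w_half), max(0, r_y - r_h_half)
x2, y2 = min(W, r_x + r_w_half), min(H, r_y + r_h_half)
# Nothing is clipped here, so the recomputed lambda equals the sampled one;
# near a corner the box shrinks and the recomputed lambda grows accordingly.
adjusted = 1.0 - (x2 - x1) * (y2 - y1) / (W * H)
print(adjusted)                              # 0.75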
test/test_transforms_tensor.py

Lines changed: 0 additions & 79 deletions
@@ -1,10 +1,6 @@
 import os
 import torch
-from torch._utils_internal import get_file_path_2
-from torch.utils.data import TensorDataset, DataLoader
-from torch.utils.data.dataloader import default_collate
 from torchvision import transforms as T
-from torchvision.io import read_image
 from torchvision.transforms import functional as F
 from torchvision.transforms import InterpolationMode
 
@@ -719,78 +715,3 @@ def test_gaussian_blur(device, meth_kwargs):
         T.GaussianBlur, meth_kwargs=meth_kwargs,
         test_exact_match=False, device=device, agg_method="max", tol=tol
     )
-
-
-@pytest.mark.parametrize('device', cpu_and_gpu())
-@pytest.mark.parametrize('tranform', [T.RandomMixup, T.RandomCutmix])
-@pytest.mark.parametrize('p', [0.0, 1.0])
-@pytest.mark.parametrize('inplace', [True, False])
-def test_random_mixupcutmix(device, tranform, p, inplace):
-    batch_size = 32
-    num_classes = 10
-    batch = torch.rand(batch_size, 3, 44, 56, device=device)
-    targets = torch.randint(num_classes, (batch_size, ), device=device, dtype=torch.int64)
-
-    fn = tranform(num_classes, p=p, inplace=inplace)
-    scripted_fn = torch.jit.script(fn)
-
-    seed = torch.seed()
-    output = fn(batch.clone(), targets.clone())
-
-    torch.manual_seed(seed)
-    output_scripted = scripted_fn(batch.clone(), targets.clone())
-    assert_equal(output[0], output_scripted[0])
-    assert_equal(output[1], output_scripted[1])
-
-    fn.__repr__()
-
-
-@pytest.mark.parametrize('tranform', [T.RandomMixup, T.RandomCutmix])
-def test_random_mixupcutmix_with_invalid_data(tranform):
-    with pytest.raises(AssertionError, match="Please provide a valid positive value for the num_classes."):
-        tranform(0)
-    with pytest.raises(AssertionError, match="Alpha param can't be zero."):
-        tranform(10, alpha=0.0)
-
-    t = tranform(10)
-    with pytest.raises(ValueError, match="Batch ndim should be 4."):
-        t(torch.rand(3, 60, 60), torch.randint(10, (1, )))
-    with pytest.raises(ValueError, match="Target ndim should be 1."):
-        t(torch.rand(32, 3, 60, 60), torch.randint(10, (32, 1)))
-    with pytest.raises(TypeError, match="Batch dtype should be a float tensor."):
-        t(torch.randint(256, (32, 3, 60, 60), dtype=torch.uint8), torch.randint(10, (32, )))
-    with pytest.raises(TypeError, match="Target dtype should be torch.int64."):
-        t(torch.rand(32, 3, 60, 60), torch.randint(10, (32, ), dtype=torch.int32))
-
-
-@pytest.mark.parametrize('device', cpu_and_gpu())
-@pytest.mark.parametrize('transform, expected', [
-    (T.RandomMixup, [60.77401351928711, 0.5151033997535706]),
-    (T.RandomCutmix, [70.13909912109375, 0.525851309299469])
-])
-def test_random_mixupcutmix_with_real_data(device, transform, expected):
-    torch.manual_seed(12)
-
-    # Build dummy dataset
-    images = []
-    for test_file in [("encode_jpeg", "grace_hopper_517x606.jpg"), ("fakedata", "logos", "rgb_pytorch.png")]:
-        fullpath = (os.path.dirname(os.path.abspath(__file__)), 'assets') + test_file
-        img = read_image(get_file_path_2(*fullpath))
-        images.append(F.resize(img, [224, 224]))
-    dataset = TensorDataset(torch.stack(images).to(device=device, dtype=torch.float32),
-                            torch.tensor([0, 1], device=device))
-
-    # Use mixup in the collate
-    trans = transform(2)
-    dataloader = DataLoader(dataset, batch_size=2, collate_fn=lambda batch: trans(*default_collate(batch)))
-
-    # Test against known statistics about the produced images
-    stats = []
-    for _ in range(25):
-        for b, t in dataloader:
-            stats.append(torch.stack([b.std(), t.std()]))
-
-    torch.testing.assert_close(
-        torch.stack(stats).mean(dim=0),
-        torch.tensor(expected, device=device)
-    )
