import math
import torch

from typing import Tuple
from torch import Tensor
from torchvision.transforms import functional as F


class RandomMixup(torch.nn.Module):
    """Randomly apply Mixup to the provided batch and targets.

    The class implements the data augmentations as described in the paper
    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for mixup.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
        super().__init__()
        assert num_classes > 0, "Please provide a valid positive value for num_classes."
        assert alpha > 0, "Alpha param can't be zero or negative."

        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B, )

        Returns:
            Tuple[Tensor, Tensor]: Randomly transformed batch and target.
        """
        if batch.ndim != 4:
            raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim))
        elif target.ndim != 1:
            raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
        elif not batch.is_floating_point():
            raise TypeError("Batch dtype should be a float tensor. Got {}.".format(batch.dtype))
        elif target.dtype != torch.int64:
            raise TypeError("Target dtype should be torch.int64. Got {}".format(target.dtype))

        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        if target.ndim == 1:
            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=torch.float32)

        if torch.rand(1).item() >= self.p:
            return batch, target

        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1)

        # Implemented as in the mixup paper, page 3: lambda ~ Beta(alpha, alpha),
        # sampled here via the equivalent two-component Dirichlet distribution.
        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
        batch_rolled.mul_(1.0 - lambda_param)
        batch.mul_(lambda_param).add_(batch_rolled)

        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)

        return batch, target

    def __repr__(self) -> str:
        s = self.__class__.__name__ + '('
        s += 'num_classes={num_classes}'
        s += ', p={p}'
        s += ', alpha={alpha}'
        s += ', inplace={inplace}'
        s += ')'
        return s.format(**self.__dict__)
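

# A minimal usage sketch (illustrative, not part of the original file): mixup
# needs whole batches, so it is typically applied in the DataLoader's collate
# function rather than per-sample. The shapes and hyperparameters below are
# assumptions for the example.
#
#   from torch.utils.data.dataloader import default_collate
#
#   mixup = RandomMixup(num_classes=10, p=0.5, alpha=0.2)
#
#   def collate_fn(samples):
#       return mixup(*default_collate(samples))
#
#   batch = torch.rand(8, 3, 32, 32)
#   target = torch.randint(0, 10, (8,))
#   batch, target = mixup(batch, target)  # target is now a (8, 10) float tensor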


class RandomCutmix(torch.nn.Module):
    """Randomly apply Cutmix to the provided batch and targets.

    The class implements the data augmentations as described in the paper
    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
    <https://arxiv.org/abs/1905.04899>`_.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for cutmix.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
        super().__init__()
        assert num_classes > 0, "Please provide a valid positive value for num_classes."
        assert alpha > 0, "Alpha param can't be zero or negative."

        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B, )

        Returns:
            Tuple[Tensor, Tensor]: Randomly transformed batch and target.
        """
        if batch.ndim != 4:
            raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim))
        elif target.ndim != 1:
            raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
        elif not batch.is_floating_point():
            raise TypeError("Batch dtype should be a float tensor. Got {}.".format(batch.dtype))
        elif target.dtype != torch.int64:
            raise TypeError("Target dtype should be torch.int64. Got {}".format(target.dtype))

        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        if target.ndim == 1:
            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=torch.float32)

        if torch.rand(1).item() >= self.p:
            return batch, target

        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1)

        # Implemented as in the cutmix paper, page 12 (with minor corrections on typos):
        # lambda ~ Beta(alpha, alpha), sampled via the equivalent Dirichlet distribution.
        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
        W, H = F.get_image_size(batch)

        # Sample the centre of the cut region uniformly over the image.
        r_x = torch.randint(W, (1,))
        r_y = torch.randint(H, (1,))

        # Half-extents of the box: before clamping to the image borders,
        # the box covers a (1 - lambda) fraction of the image area.
        r = 0.5 * math.sqrt(1.0 - lambda_param)
        r_w_half = int(r * W)
        r_h_half = int(r * H)

        x1 = int(torch.clamp(r_x - r_w_half, min=0))
        y1 = int(torch.clamp(r_y - r_h_half, min=0))
        x2 = int(torch.clamp(r_x + r_w_half, max=W))
        y2 = int(torch.clamp(r_y + r_h_half, max=H))

        batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
        # Recompute lambda from the actual (clamped) box area.
        lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)

        return batch, target

    def __repr__(self) -> str:
        s = self.__class__.__name__ + '('
        s += 'num_classes={num_classes}'
        s += ', p={p}'
        s += ', alpha={alpha}'
        s += ', inplace={inplace}'
        s += ')'
        return s.format(**self.__dict__)
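

# Sketch of how the two transforms are commonly combined (as in torchvision's
# reference classification script): one of them is chosen at random per batch
# via torchvision.transforms.RandomChoice, inside the collate function. The
# dataset, batch size, and alpha values below are illustrative assumptions.
#
#   import torchvision
#   from torch.utils.data.dataloader import default_collate
#
#   mixupcutmix = torchvision.transforms.RandomChoice([
#       RandomMixup(num_classes=1000, p=1.0, alpha=0.2),
#       RandomCutmix(num_classes=1000, p=1.0, alpha=1.0),
#   ])
#
#   def collate_fn(samples):
#       return mixupcutmix(*default_collate(samples))
#
#   data_loader = torch.utils.data.DataLoader(
#       dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)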