 "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
 "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
 "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize",
-"RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize", 'RandomMixupCutmix']
+"RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize", 'RandomMixup',
+"RandomCutmix"]


class Compose:
@@ -515,9 +516,20 @@ def __call__(self, img):
class RandomChoice(RandomTransforms):
    """Apply single transformation randomly picked from a list. This transform does not support torchscript.
    """
-    def __call__(self, img):
-        t = random.choice(self.transforms)
-        return t(img)
+    def __init__(self, transforms, p=None):
+        super().__init__(transforms)
+        if p is not None and not isinstance(p, Sequence):
+            raise TypeError("Argument p should be a sequence")
+        self.p = p
+
+    def __call__(self, *args):
+        t = random.choices(self.transforms, weights=self.p)[0]
+        return t(*args)
+
+    def __repr__(self):
+        format_string = super().__repr__()
+        format_string += '(p={0})'.format(self.p)
+        return format_string


class RandomCrop(torch.nn.Module):
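For context: the new p argument is passed straight to random.choices, which treats it as relative weights (they need not sum to 1, and None means uniform), and __call__ now forwards *args so a choice can be applied to a (batch, target) pair. A minimal usage sketch under this signature; the transforms and weights below are illustrative, not part of the diff:

    from PIL import Image
    import torchvision.transforms as T

    img = Image.new("RGB", (64, 64))
    # Pick ColorJitter twice as often as GaussianBlur; the weights are
    # relative, exactly as random.choices interprets them.
    transform = T.RandomChoice(
        [T.ColorJitter(brightness=0.4), T.GaussianBlur(kernel_size=3)],
        p=[2, 1],
    )
    out = transform(img)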
@@ -1956,38 +1968,103 @@ def __repr__(self):
# TODO: move this to references before merging and delete the tests
-class RandomMixupCutmix(torch.nn.Module):
-    """Randomly apply Mixup or Cutmix to the provided batch and targets.
-    The class implements the data augmentations as described in the papers
-    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_ and
+class RandomMixup(torch.nn.Module):
+    """Randomly apply Mixup to the provided batch and targets.
+    The class implements the data augmentation as described in the paper
+    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
+
+    Args:
+        num_classes (int): number of classes used for one-hot encoding.
+        p (float): probability of the batch being transformed. Default value is 0.5.
+        alpha (float): hyperparameter of the Beta distribution used for mixup.
+            Default value is 1.0.
+        inplace (bool): boolean to make this transform inplace. Default set to False.
+    """
+
+    def __init__(self, num_classes: int,
+                 p: float = 0.5, alpha: float = 1.0,
+                 inplace: bool = False) -> None:
+        super().__init__()
+        assert num_classes > 0, "Please provide a valid positive value for the num_classes."
+        assert alpha > 0, "Alpha param can't be zero."
+
+        self.num_classes = num_classes
+        self.p = p
+        self.alpha = alpha
+        self.inplace = inplace
+
+    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
+        """
+        Args:
+            batch (Tensor): Float tensor of size (B, C, H, W)
+            target (Tensor): Integer tensor of size (B, )
+
+        Returns:
+            Tuple[Tensor, Tensor]: Randomly transformed batch and target.
+        """
+        if batch.ndim != 4:
+            raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim))
+        elif target.ndim != 1:
+            raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
+        elif target.dtype != torch.int64:
+            raise ValueError("Target dtype should be torch.int64. Got {}".format(target.dtype))
+
+        if not self.inplace:
+            batch = batch.clone()
+            # target = target.clone()
+
+        target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=torch.float32)
+        if torch.rand(1).item() >= self.p:
+            return batch, target
+
+        # It's faster to roll the batch by one instead of shuffling it to create image pairs
+        batch_rolled = batch.roll(1, 0)
+        target_rolled = target.roll(1, 0)
+
+        # Implemented as in the mixup paper, page 3.
+        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
+        batch_rolled.mul_(1.0 - lambda_param)
+        batch.mul_(lambda_param).add_(batch_rolled)
+
+        target_rolled.mul_(1.0 - lambda_param)
+        target.mul_(lambda_param).add_(target_rolled)
+
+        return batch, target
+
+    def __repr__(self) -> str:
+        s = self.__class__.__name__ + '('
+        s += 'num_classes={num_classes}'
+        s += ', p={p}'
+        s += ', alpha={alpha}'
+        s += ', inplace={inplace}'
+        s += ')'
+        return s.format(**self.__dict__)
+
+
+class RandomCutmix(torch.nn.Module):
+    """Randomly apply Cutmix to the provided batch and targets.
+    The class implements the data augmentation as described in the paper
    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
    <https://arxiv.org/abs/1905.04899>`_.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
-        p (float): probability of the batch being transformed. Default value is 1.0.
-        mixup_alpha (float): hyperparameter of the Beta distribution used for mixup.
-            Set to 0.0 to turn off. Default value is 1.0.
-        cutmix_p (float): probability of using cutmix instead of mixup when both are on.
-            Default value is 0.5.
-        cutmix_alpha (float): hyperparameter of the Beta distribution used for cutmix.
-            Set to 0.0 to turn off. Default value is 0.0.
+        p (float): probability of the batch being transformed. Default value is 0.5.
+        alpha (float): hyperparameter of the Beta distribution used for cutmix.
+            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(self, num_classes: int,
-                 p: float = 1.0, mixup_alpha: float = 1.0,
-                 cutmix_p: float = 0.5, cutmix_alpha: float = 0.0,
+                 p: float = 0.5, alpha: float = 1.0,
                 inplace: bool = False) -> None:
        super().__init__()
        assert num_classes > 0, "Please provide a valid positive value for the num_classes."
-        assert mixup_alpha > 0 or cutmix_alpha > 0, "Both alpha params can't be zero."
+        assert alpha > 0, "Alpha param can't be zero."

        self.num_classes = num_classes
        self.p = p
-        self.mixup_alpha = mixup_alpha
-        self.cutmix_p = cutmix_p
-        self.cutmix_alpha = cutmix_alpha
+        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
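Worth noting: torch._sample_dirichlet is a private helper; with concentration [alpha, alpha] the first coordinate of the sample is Beta(alpha, alpha)-distributed, which is exactly the mixing coefficient lambda from the mixup paper. A rough, hypothetical equivalent of RandomMixup's mixing step written against the public torch.distributions API (the function name is illustrative):

    import torch

    def mixup_pair(batch: torch.Tensor, one_hot: torch.Tensor, alpha: float = 1.0):
        # lam ~ Beta(alpha, alpha), as in the mixup paper, page 3
        lam = torch.distributions.Beta(alpha, alpha).sample().item()
        # Pair sample i with sample i-1 by rolling the batch along dim 0
        mixed_x = lam * batch + (1.0 - lam) * batch.roll(1, 0)
        mixed_y = lam * one_hot + (1.0 - lam) * one_hot.roll(1, 0)
        return mixed_x, mixed_y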
@@ -2018,35 +2095,24 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

-        if self.mixup_alpha <= 0.0:
-            use_mixup = False
-        else:
-            use_mixup = self.cutmix_alpha <= 0.0 or torch.rand(1).item() >= self.cutmix_p
-
-        if use_mixup:
-            # Implemented as on mixup paper, page 3.
-            lambda_param = float(torch._sample_dirichlet(torch.tensor([self.mixup_alpha, self.mixup_alpha]))[0])
-            batch_rolled.mul_(1.0 - lambda_param)
-            batch.mul_(lambda_param).add_(batch_rolled)
-        else:
-            # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
-            lambda_param = float(torch._sample_dirichlet(torch.tensor([self.cutmix_alpha, self.cutmix_alpha]))[0])
-            W, H = F.get_image_size(batch)
+        # Implemented as in the cutmix paper, page 12 (with minor corrections on typos).
+        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
+        W, H = F.get_image_size(batch)

-            r_x = torch.randint(W, (1,))
-            r_y = torch.randint(H, (1,))
+        r_x = torch.randint(W, (1,))
+        r_y = torch.randint(H, (1,))

-            r = 0.5 * math.sqrt(1.0 - lambda_param)
-            r_w_half = int(r * W)
-            r_h_half = int(r * H)
+        r = 0.5 * math.sqrt(1.0 - lambda_param)
+        r_w_half = int(r * W)
+        r_h_half = int(r * H)

-            x1 = int(torch.clamp(r_x - r_w_half, min=0))
-            y1 = int(torch.clamp(r_y - r_h_half, min=0))
-            x2 = int(torch.clamp(r_x + r_w_half, max=W))
-            y2 = int(torch.clamp(r_y + r_h_half, max=H))
+        x1 = int(torch.clamp(r_x - r_w_half, min=0))
+        y1 = int(torch.clamp(r_y - r_h_half, min=0))
+        x2 = int(torch.clamp(r_x + r_w_half, max=W))
+        y2 = int(torch.clamp(r_y + r_h_half, max=H))

-            batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
-            lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
+        batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
+        lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)
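Sanity check on the geometry: the half-extents are chosen so that an unclamped patch covers exactly a 1 - lambda fraction of the image, and the final reassignment of lambda_param corrects for clamping at the borders. A worked example with hypothetical numbers:

    import math

    W, H = 224, 224
    lambda_param = 0.75
    r = 0.5 * math.sqrt(1.0 - lambda_param)      # 0.25
    r_w_half, r_h_half = int(r * W), int(r * H)  # 56, 56
    # Unclamped box: (2 * 56) x (2 * 56) = 112 x 112 pixels
    area_fraction = (2 * r_w_half) * (2 * r_h_half) / (W * H)  # 0.25
    adjusted_lambda = 1.0 - area_fraction                      # 0.75
    # Away from the borders the adjusted lambda matches the sampled one;
    # near them the clamped box shrinks and the adjustment accounts for it.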
@@ -2057,9 +2123,7 @@ def __repr__(self) -> str:
        s = self.__class__.__name__ + '('
        s += 'num_classes={num_classes}'
        s += ', p={p}'
-        s += ', mixup_alpha={mixup_alpha}'
-        s += ', cutmix_p={cutmix_p}'
-        s += ', cutmix_alpha={cutmix_alpha}'
+        s += ', alpha={alpha}'
        s += ', inplace={inplace}'
        s += ')'
        return s.format(**self.__dict__)
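Given the TODO about moving these transforms to references, here is a sketch of how they might be combined batch-wise through the extended RandomChoice inside a collate function. The dataset, num_classes, and alpha values are hypothetical, and default_collate is the stock PyTorch collate that stacks samples into (B, C, H, W) and (B,) tensors:

    import torch
    from torch.utils.data import DataLoader
    from torch.utils.data.dataloader import default_collate

    # Assumes RandomChoice, RandomMixup, and RandomCutmix from this
    # module are in scope.
    mixup_or_cutmix = RandomChoice([
        RandomMixup(num_classes=1000, p=1.0, alpha=0.2),
        RandomCutmix(num_classes=1000, p=1.0, alpha=1.0),
    ])

    def collate_fn(samples):
        batch, target = default_collate(samples)
        # RandomChoice.__call__ forwards *args, so the picked transform
        # receives both the image batch and the integer class targets.
        return mixup_or_cutmix(batch, target)

    # loader = DataLoader(train_dataset, batch_size=32, collate_fn=collate_fn)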