@@ -25,14 +25,16 @@ def __init__(
25
25
* ,
26
26
interpolation : Union [InterpolationMode , int ] = InterpolationMode .NEAREST ,
27
27
fill : Union [_FillType , Dict [Union [Type , str ], _FillType ]] = None ,
28
+ generator = None ,
28
29
) -> None :
29
30
super ().__init__ ()
30
31
self .interpolation = _check_interpolation (interpolation )
31
32
self .fill = _setup_fill_arg (fill )
33
+ self .generator = generator
32
34
33
35
def _get_random_item (self , dct : Dict [str , Tuple [Callable , bool ]]) -> Tuple [str , Tuple [Callable , bool ]]:
34
36
keys = tuple (dct .keys ())
35
- key = keys [int (torch .randint (len (keys ), ()))]
37
+ key = keys [int (torch .randint (len (keys ), (), generator = self . generator ))]
36
38
return key , dct [key ]
37
39
38
40
def _flatten_and_extract_image_or_video (
@@ -219,8 +221,9 @@ def __init__(
219
221
policy : AutoAugmentPolicy = AutoAugmentPolicy .IMAGENET ,
220
222
interpolation : Union [InterpolationMode , int ] = InterpolationMode .NEAREST ,
221
223
fill : Union [_FillType , Dict [Union [Type , str ], _FillType ]] = None ,
224
+ generator = None ,
222
225
) -> None :
223
- super ().__init__ (interpolation = interpolation , fill = fill )
226
+ super ().__init__ (interpolation = interpolation , fill = fill , generator = generator )
224
227
self .policy = policy
225
228
self ._policies = self ._get_policies (policy )
226
229
@@ -318,18 +321,18 @@ def forward(self, *inputs: Any) -> Any:
318
321
flat_inputs_with_spec , image_or_video = self ._flatten_and_extract_image_or_video (inputs )
319
322
height , width = get_size (image_or_video )
320
323
321
- policy = self ._policies [int (torch .randint (len (self ._policies ), ()))]
324
+ policy = self ._policies [int (torch .randint (len (self ._policies ), (), generator = self . generator ))]
322
325
323
326
for transform_id , probability , magnitude_idx in policy :
324
- if not torch .rand (()) <= probability :
327
+ if not torch .rand ((), generator = self . generator ) <= probability :
325
328
continue
326
329
327
330
magnitudes_fn , signed = self ._AUGMENTATION_SPACE [transform_id ]
328
331
329
332
magnitudes = magnitudes_fn (10 , height , width )
330
333
if magnitudes is not None :
331
334
magnitude = float (magnitudes [magnitude_idx ])
332
- if signed and torch .rand (()) <= 0.5 :
335
+ if signed and torch .rand ((), generator = self . generator ) <= 0.5 :
333
336
magnitude *= - 1
334
337
else :
335
338
magnitude = 0.0
@@ -399,8 +402,9 @@ def __init__(
399
402
num_magnitude_bins : int = 31 ,
400
403
interpolation : Union [InterpolationMode , int ] = InterpolationMode .NEAREST ,
401
404
fill : Union [_FillType , Dict [Union [Type , str ], _FillType ]] = None ,
405
+ generator = None ,
402
406
) -> None :
403
- super ().__init__ (interpolation = interpolation , fill = fill )
407
+ super ().__init__ (interpolation = interpolation , fill = fill , generator = generator )
404
408
self .num_ops = num_ops
405
409
self .magnitude = magnitude
406
410
self .num_magnitude_bins = num_magnitude_bins
@@ -414,7 +418,7 @@ def forward(self, *inputs: Any) -> Any:
414
418
magnitudes = magnitudes_fn (self .num_magnitude_bins , height , width )
415
419
if magnitudes is not None :
416
420
magnitude = float (magnitudes [self .magnitude ])
417
- if signed and torch .rand (()) <= 0.5 :
421
+ if signed and torch .rand ((), generator = self . generator ) <= 0.5 :
418
422
magnitude *= - 1
419
423
else :
420
424
magnitude = 0.0
@@ -472,8 +476,9 @@ def __init__(
472
476
num_magnitude_bins : int = 31 ,
473
477
interpolation : Union [InterpolationMode , int ] = InterpolationMode .NEAREST ,
474
478
fill : Union [_FillType , Dict [Union [Type , str ], _FillType ]] = None ,
479
+ generator = None ,
475
480
):
476
- super ().__init__ (interpolation = interpolation , fill = fill )
481
+ super ().__init__ (interpolation = interpolation , fill = fill , generator = generator )
477
482
self .num_magnitude_bins = num_magnitude_bins
478
483
479
484
def forward (self , * inputs : Any ) -> Any :
@@ -484,8 +489,8 @@ def forward(self, *inputs: Any) -> Any:
484
489
485
490
magnitudes = magnitudes_fn (self .num_magnitude_bins , height , width )
486
491
if magnitudes is not None :
487
- magnitude = float (magnitudes [int (torch .randint (self .num_magnitude_bins , ()))])
488
- if signed and torch .rand (()) <= 0.5 :
492
+ magnitude = float (magnitudes [int (torch .randint (self .num_magnitude_bins , (), generator = self . generator ))])
493
+ if signed and torch .rand ((), generator = self . generator ) <= 0.5 :
489
494
magnitude *= - 1
490
495
else :
491
496
magnitude = 0.0
@@ -555,8 +560,9 @@ def __init__(
555
560
all_ops : bool = True ,
556
561
interpolation : Union [InterpolationMode , int ] = InterpolationMode .BILINEAR ,
557
562
fill : Union [_FillType , Dict [Union [Type , str ], _FillType ]] = None ,
563
+ generator = None ,
558
564
) -> None :
559
- super ().__init__ (interpolation = interpolation , fill = fill )
565
+ super ().__init__ (interpolation = interpolation , fill = fill , generator = generator )
560
566
self ._PARAMETER_MAX = 10
561
567
if not (1 <= severity <= self ._PARAMETER_MAX ):
562
568
raise ValueError (f"The severity must be between [1, { self ._PARAMETER_MAX } ]. Got { severity } instead." )
@@ -601,14 +607,18 @@ def forward(self, *inputs: Any) -> Any:
601
607
mix = m [:, 0 ].reshape (batch_dims ) * batch
602
608
for i in range (self .mixture_width ):
603
609
aug = batch
604
- depth = self .chain_depth if self .chain_depth > 0 else int (torch .randint (low = 1 , high = 4 , size = (1 ,)).item ())
610
+ depth = (
611
+ self .chain_depth
612
+ if self .chain_depth > 0
613
+ else int (torch .randint (low = 1 , high = 4 , size = (1 ,), generator = self .generator ).item ())
614
+ )
605
615
for _ in range (depth ):
606
616
transform_id , (magnitudes_fn , signed ) = self ._get_random_item (augmentation_space )
607
617
608
618
magnitudes = magnitudes_fn (self ._PARAMETER_MAX , height , width )
609
619
if magnitudes is not None :
610
- magnitude = float (magnitudes [int (torch .randint (self .severity , ()))])
611
- if signed and torch .rand (()) <= 0.5 :
620
+ magnitude = float (magnitudes [int (torch .randint (self .severity , (), generator = self . generator ))])
621
+ if signed and torch .rand ((), generator = self . generator ) <= 0.5 :
612
622
magnitude *= - 1
613
623
else :
614
624
magnitude = 0.0
0 commit comments