@@ -35,7 +35,8 @@ def __init__(
35
35
antialias : Optional [bool ] = None ,
36
36
) -> None :
37
37
super ().__init__ ()
38
- self .size = [size ] if isinstance (size , int ) else list (size )
38
+
39
+ self .size = _setup_size (size , error_msg = "Please provide only two dimensions (h, w) for size." )
39
40
self .interpolation = interpolation
40
41
self .max_size = max_size
41
42
self .antialias = antialias
@@ -80,7 +81,6 @@ def __init__(
80
81
if (scale [0 ] > scale [1 ]) or (ratio [0 ] > ratio [1 ]):
81
82
warnings .warn ("Scale and ratio should be of kind (min, max)" )
82
83
83
- self .size = size
84
84
self .scale = scale
85
85
self .ratio = ratio
86
86
self .interpolation = interpolation
@@ -225,6 +225,19 @@ def _check_fill_arg(fill: Union[int, float, Sequence[int], Sequence[float]]) ->
225
225
raise TypeError ("Got inappropriate fill arg" )
226
226
227
227
228
+ def _check_padding_arg (padding : Union [int , Sequence [int ]]) -> None :
229
+ if not isinstance (padding , (numbers .Number , tuple , list )):
230
+ raise TypeError ("Got inappropriate padding arg" )
231
+
232
+ if isinstance (padding , (tuple , list )) and len (padding ) not in [1 , 2 , 4 ]:
233
+ raise ValueError (f"Padding must be an int or a 1, 2, or 4 element tuple, not a { len (padding )} element tuple" )
234
+
235
+
236
+ def _check_padding_mode_arg (padding_mode : Literal ["constant" , "edge" , "reflect" , "symmetric" ]) -> None :
237
+ if padding_mode not in ["constant" , "edge" , "reflect" , "symmetric" ]:
238
+ raise ValueError ("Padding mode should be either constant, edge, reflect or symmetric" )
239
+
240
+
228
241
class Pad (Transform ):
229
242
def __init__ (
230
243
self ,
@@ -233,18 +246,10 @@ def __init__(
233
246
padding_mode : Literal ["constant" , "edge" , "reflect" , "symmetric" ] = "constant" ,
234
247
) -> None :
235
248
super ().__init__ ()
236
- if not isinstance (padding , (numbers .Number , tuple , list )):
237
- raise TypeError ("Got inappropriate padding arg" )
238
-
239
- if isinstance (padding , (tuple , list )) and len (padding ) not in [1 , 2 , 4 ]:
240
- raise ValueError (
241
- f"Padding must be an int or a 1, 2, or 4 element tuple, not a { len (padding )} element tuple"
242
- )
243
249
250
+ _check_padding_arg (padding )
244
251
_check_fill_arg (fill )
245
-
246
- if padding_mode not in ["constant" , "edge" , "reflect" , "symmetric" ]:
247
- raise ValueError ("Padding mode should be either constant, edge, reflect or symmetric" )
252
+ _check_padding_mode_arg (padding_mode )
248
253
249
254
self .padding = padding
250
255
self .fill = fill
@@ -416,3 +421,75 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
416
421
fill = self .fill ,
417
422
center = self .center ,
418
423
)
424
+
425
+
426
+ class RandomCrop (Transform ):
427
+ def __init__ (
428
+ self ,
429
+ size : Union [int , Sequence [int ]],
430
+ padding : Optional [Union [int , Sequence [int ]]] = None ,
431
+ pad_if_needed : bool = False ,
432
+ fill : Union [int , float , Sequence [int ], Sequence [float ]] = 0 ,
433
+ padding_mode : Literal ["constant" , "edge" , "reflect" , "symmetric" ] = "constant" ,
434
+ ) -> None :
435
+ super ().__init__ ()
436
+
437
+ self .size = _setup_size (size , error_msg = "Please provide only two dimensions (h, w) for size." )
438
+
439
+ if padding is not None :
440
+ _check_padding_arg (padding )
441
+
442
+ if (padding is not None ) or pad_if_needed :
443
+ _check_padding_mode_arg (padding_mode )
444
+ _check_fill_arg (fill )
445
+
446
+ self .padding = padding
447
+ self .pad_if_needed = pad_if_needed
448
+ self .fill = fill
449
+ self .padding_mode = padding_mode
450
+
451
+ def _get_params (self , sample : Any ) -> Dict [str , Any ]:
452
+ image = query_image (sample )
453
+ _ , height , width = get_image_dimensions (image )
454
+ output_height , output_width = self .size
455
+
456
+ if height + 1 < output_height or width + 1 < output_width :
457
+ raise ValueError (
458
+ f"Required crop size { (output_height , output_width )} is larger then input image size { (height , width )} "
459
+ )
460
+
461
+ if width == output_width and height == output_height :
462
+ return dict (top = 0 , left = 0 , height = height , width = width )
463
+
464
+ top = torch .randint (0 , height - output_height + 1 , size = (1 ,)).item ()
465
+ left = torch .randint (0 , width - output_width + 1 , size = (1 ,)).item ()
466
+ return dict (top = top , left = left , height = output_height , width = output_width )
467
+
468
+ def _forward (self , flat_inputs : List [Any ]) -> List [Any ]:
469
+ if self .padding is not None :
470
+ flat_inputs = [F .pad (flat_input , self .padding , self .fill , self .padding_mode ) for flat_input in flat_inputs ]
471
+
472
+ image = query_image (flat_inputs )
473
+ _ , height , width = get_image_dimensions (image )
474
+
475
+ # pad the width if needed
476
+ if self .pad_if_needed and width < self .size [1 ]:
477
+ padding = [self .size [1 ] - width , 0 ]
478
+ flat_inputs = [F .pad (flat_input , padding , self .fill , self .padding_mode ) for flat_input in flat_inputs ]
479
+ # pad the height if needed
480
+ if self .pad_if_needed and height < self .size [0 ]:
481
+ padding = [0 , self .size [0 ] - height ]
482
+ flat_inputs = [F .pad (flat_input , padding , self .fill , self .padding_mode ) for flat_input in flat_inputs ]
483
+
484
+ params = self ._get_params (flat_inputs )
485
+
486
+ return [F .crop (flat_input , ** params ) for flat_input in flat_inputs ]
487
+
488
+ def forward (self , * inputs : Any ) -> Any :
489
+ from torch .utils ._pytree import tree_flatten , tree_unflatten
490
+
491
+ sample = inputs if len (inputs ) > 1 else inputs [0 ]
492
+
493
+ flat_inputs , spec = tree_flatten (sample )
494
+ out_flat_inputs = self ._forward (flat_inputs )
495
+ return tree_unflatten (out_flat_inputs , spec )
0 commit comments