@@ -641,12 +641,14 @@ def backward(ctx, grad_output):
    assert torch.autograd.gradcheck(F.apply, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)


-def check_functional_vs_PIL_vs_scripted(fn, fn_pil, fn_t, config, device, dtype, tol=2.0 + 1e-10, agg_method="max"):
+def check_functional_vs_PIL_vs_scripted(
+    fn, fn_pil, fn_t, config, device, dtype, channels=3, tol=2.0 + 1e-10, agg_method="max"
+):

    script_fn = torch.jit.script(fn)
    torch.manual_seed(15)
-    tensor, pil_img = _create_data(26, 34, device=device)
-    batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
+    tensor, pil_img = _create_data(26, 34, channels=channels, device=device)
+    batch_tensors = _create_data_batch(16, 18, num_samples=4, channels=channels, device=device)

    if dtype is not None:
        tensor = F.convert_image_dtype(tensor, dtype)
@@ -798,14 +800,16 @@ def test_equalize(device):
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
@pytest.mark.parametrize('config', [{"contrast_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]])
-def test_adjust_contrast(device, dtype, config):
+@pytest.mark.parametrize('channels', [1, 3])
+def test_adjust_contrast(device, dtype, config, channels):
    check_functional_vs_PIL_vs_scripted(
        F.adjust_contrast,
        F_pil.adjust_contrast,
        F_t.adjust_contrast,
        config,
        device,
-        dtype
+        dtype,
+        channels=channels
    )
