@@ -896,7 +896,7 @@ def _make_input(
896
896
return input
897
897
898
898
def _adapt_fill (self , value , * , dtype ):
899
- """Adapt `` fill`` values in the range `` [0.0, 1.0]`` to the value range of the dtype"""
899
+ """Adapt fill values in the range [0.0, 1.0] to the value range of the dtype"""
900
900
if value is None :
901
901
return value
902
902
@@ -907,9 +907,7 @@ def _adapt_fill(self, value, *, dtype):
907
907
elif isinstance (value , (list , tuple )):
908
908
return type (value )(type (v )(v * max_value ) for v in value )
909
909
else :
910
- raise pytest .UsageError (
911
- f"`fill` should be an int or float, or a list or tuple of the former, but got { value } "
912
- )
910
+ raise ValueError (f"fill should be an int or float, or a list or tuple of the former, but got '{ value } '" )
913
911
914
912
_EXHAUSTIVE_TYPE_AFFINE_KWARGS = dict (
915
913
# float, int
@@ -926,7 +924,7 @@ def _adapt_fill(self, value, *, dtype):
926
924
# two-list of float, two-list of int, two-tuple of float, two-tuple of int
927
925
center = [None , [1.2 , 4.9 ], [- 3 , 1 ], (2.5 , - 4.7 ), (3 , 2 )],
928
926
)
929
- # The special case for `" shear"` makes sure we pick a value that is supported while JIT scripting
927
+ # The special case for shear makes sure we pick a value that is supported while JIT scripting
930
928
_MINIMAL_AFFINE_KWARGS = {
931
929
k : vs [0 ] if k != "shear" else next (v for v in vs if isinstance (v , list ))
932
930
for k , vs in _EXHAUSTIVE_TYPE_AFFINE_KWARGS .items ()
@@ -1110,14 +1108,14 @@ def test_functional_image_correctness(self, angle, translate, scale, shear, cent
1110
1108
)
1111
1109
1112
1110
mae = (actual .float () - expected .float ()).abs ().mean ()
1113
- assert mae < 5
1111
+ assert mae < 2 if interpolation is transforms . InterpolationMode . NEAREST else 8
1114
1112
1115
1113
@pytest .mark .parametrize ("center" , _CORRECTNESS_AFFINE_KWARGS ["center" ])
1116
1114
@pytest .mark .parametrize (
1117
1115
"interpolation" , [transforms .InterpolationMode .NEAREST , transforms .InterpolationMode .BILINEAR ]
1118
1116
)
1119
1117
@pytest .mark .parametrize ("fill" , _CORRECTNESS_FILL )
1120
- @pytest .mark .parametrize ("seed" , list (range (10 )))
1118
+ @pytest .mark .parametrize ("seed" , list (range (5 )))
1121
1119
def test_transform_image_correctness (self , center , interpolation , fill , seed ):
1122
1120
image = self ._make_input (torch .Tensor , dtype = torch .uint8 , device = "cpu" )
1123
1121
@@ -1127,18 +1125,14 @@ def test_transform_image_correctness(self, center, interpolation, fill, seed):
1127
1125
** self ._CORRECTNESS_TRANSFORM_AFFINE_RANGES , center = center , interpolation = interpolation , fill = fill
1128
1126
)
1129
1127
1130
- torch .manual_seed (seed )
1131
- params = transform ._get_params ([image ])
1132
-
1133
1128
torch .manual_seed (seed )
1134
1129
actual = transform (image )
1135
1130
1136
- expected = F .to_image_tensor (
1137
- F .affine (F .to_image_pil (image ), ** params , center = center , interpolation = interpolation , fill = fill )
1138
- )
1131
+ torch .manual_seed (seed )
1132
+ expected = F .to_image_tensor (transform (F .to_image_pil (image )))
1139
1133
1140
1134
mae = (actual .float () - expected .float ()).abs ().mean ()
1141
- assert mae < 7
1135
+ assert mae < 2 if interpolation is transforms . InterpolationMode . NEAREST else 8
1142
1136
1143
1137
def _compute_affine_matrix (self , * , angle , translate , scale , shear , center ):
1144
1138
rot = math .radians (angle )
@@ -1210,7 +1204,7 @@ def test_functional_bounding_box_correctness(self, format, angle, translate, sca
1210
1204
1211
1205
@pytest .mark .parametrize ("format" , list (datapoints .BoundingBoxFormat ))
1212
1206
@pytest .mark .parametrize ("center" , _CORRECTNESS_AFFINE_KWARGS ["center" ])
1213
- @pytest .mark .parametrize ("seed" , list (range (10 )))
1207
+ @pytest .mark .parametrize ("seed" , list (range (5 )))
1214
1208
def test_transform_bounding_box_correctness (self , format , center , seed ):
1215
1209
bounding_box = self ._make_input (datapoints .BoundingBox , format = format )
1216
1210
0 commit comments