@@ -963,57 +963,135 @@ def _adapt_fill(self, value, *, dtype):
963
963
k : next (v for v in vs if v is not None ) for k , vs in _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES .items ()
964
964
}
965
965
966
+ def _check_kernel (self , kernel , input , * args , ** kwargs ):
967
+ kwargs_ = self ._MINIMAL_AFFINE_KWARGS .copy ()
968
+ kwargs_ .update (kwargs )
969
+ check_kernel (kernel , input , * args , ** kwargs_ )
970
+
966
971
@pytest .mark .parametrize ("angle" , _EXHAUSTIVE_TYPE_AFFINE_KWARGS ["angle" ])
972
+ @pytest .mark .parametrize ("dtype" , [torch .float32 , torch .uint8 ])
973
+ @pytest .mark .parametrize ("device" , cpu_and_cuda ())
974
+ def test_kernel_image_tensor_angle (self , angle , dtype , device ):
975
+ self ._check_kernel (
976
+ F .affine_image_tensor ,
977
+ self ._make_input (torch .Tensor , dtype = dtype , device = device ),
978
+ angle = angle ,
979
+ )
980
+
967
981
@pytest .mark .parametrize ("translate" , _EXHAUSTIVE_TYPE_AFFINE_KWARGS ["translate" ])
968
- @pytest .mark .parametrize ("scale" , _EXHAUSTIVE_TYPE_AFFINE_KWARGS ["scale" ])
982
+ @pytest .mark .parametrize ("dtype" , [torch .float32 , torch .uint8 ])
983
+ @pytest .mark .parametrize ("device" , cpu_and_cuda ())
984
+ def test_kernel_image_tensor_translate (self , translate , dtype , device ):
985
+ self ._check_kernel (
986
+ F .affine_image_tensor ,
987
+ self ._make_input (torch .Tensor , dtype = dtype , device = device ),
988
+ translate = translate ,
989
+ )
990
+
969
991
@pytest .mark .parametrize ("shear" , _EXHAUSTIVE_TYPE_AFFINE_KWARGS ["shear" ])
992
+ @pytest .mark .parametrize ("dtype" , [torch .float32 , torch .uint8 ])
993
+ @pytest .mark .parametrize ("device" , cpu_and_cuda ())
994
+ def test_kernel_image_tensor_shear (self , shear , dtype , device ):
995
+ self ._check_kernel (
996
+ F .affine_image_tensor ,
997
+ self ._make_input (torch .Tensor , dtype = dtype , device = device ),
998
+ shear = shear ,
999
+ check_scripted_vs_eager = not isinstance (shear , (int , float )),
1000
+ )
1001
+
970
1002
@pytest .mark .parametrize ("center" , _EXHAUSTIVE_TYPE_AFFINE_KWARGS ["center" ])
1003
+ @pytest .mark .parametrize ("dtype" , [torch .float32 , torch .uint8 ])
1004
+ @pytest .mark .parametrize ("device" , cpu_and_cuda ())
1005
+ def test_kernel_image_tensor_center (self , center , dtype , device ):
1006
+ self ._check_kernel (
1007
+ F .affine_image_tensor ,
1008
+ self ._make_input (torch .Tensor , dtype = dtype , device = device ),
1009
+ center = center ,
1010
+ )
1011
+
971
1012
@pytest .mark .parametrize (
972
1013
"interpolation" , [transforms .InterpolationMode .NEAREST , transforms .InterpolationMode .BILINEAR ]
973
1014
)
974
- @pytest .mark .parametrize ("fill" , _EXHAUSTIVE_TYPE_FILLS )
975
1015
@pytest .mark .parametrize ("dtype" , [torch .float32 , torch .uint8 ])
976
1016
@pytest .mark .parametrize ("device" , cpu_and_cuda ())
977
- def test_kernel_image_tensor (self , angle , translate , scale , shear , center , interpolation , fill , dtype , device ):
978
- check_kernel (
1017
+ def test_kernel_image_tensor_interpolation (self , interpolation , dtype , device ):
1018
+ self . _check_kernel (
979
1019
F .affine_image_tensor ,
980
1020
self ._make_input (torch .Tensor , dtype = dtype , device = device ),
981
- angle = angle ,
982
- translate = translate ,
983
- scale = scale ,
984
- shear = shear ,
985
- center = center ,
986
1021
interpolation = interpolation ,
987
- fill = self ._adapt_fill (fill , dtype = dtype ),
988
- check_scripted_vs_eager = not (isinstance (shear , (int , float )) or isinstance (fill , (int , float ))),
989
1022
check_cuda_vs_cpu = dict (atol = 1 , rtol = 0 )
990
1023
if dtype is torch .uint8 and interpolation is transforms .InterpolationMode .BILINEAR
991
1024
else True ,
992
1025
)
993
1026
994
- @pytest .mark .parametrize ("format" , list (datapoints .BoundingBoxFormat ))
1027
+ @pytest .mark .parametrize ("fill" , _EXHAUSTIVE_TYPE_FILLS )
1028
+ @pytest .mark .parametrize ("dtype" , [torch .float32 , torch .uint8 ])
1029
+ @pytest .mark .parametrize ("device" , cpu_and_cuda ())
1030
+ def test_kernel_image_tensor_fill (self , fill , dtype , device ):
1031
+ self ._check_kernel (
1032
+ F .affine_image_tensor ,
1033
+ self ._make_input (torch .Tensor , dtype = dtype , device = device ),
1034
+ fill = self ._adapt_fill (fill , dtype = dtype ),
1035
+ check_scripted_vs_eager = not isinstance (fill , (int , float )),
1036
+ )
1037
+
995
1038
@pytest .mark .parametrize ("angle" , _EXHAUSTIVE_TYPE_AFFINE_KWARGS ["angle" ])
996
- @pytest .mark .parametrize ("translate" , _EXHAUSTIVE_TYPE_AFFINE_KWARGS ["translate" ])
997
- @pytest .mark .parametrize ("scale" , _EXHAUSTIVE_TYPE_AFFINE_KWARGS ["scale" ])
998
- @pytest .mark .parametrize ("shear" , _EXHAUSTIVE_TYPE_AFFINE_KWARGS ["shear" ])
999
- @pytest .mark .parametrize ("center" , _EXHAUSTIVE_TYPE_AFFINE_KWARGS ["center" ])
1039
+ @pytest .mark .parametrize ("format" , list (datapoints .BoundingBoxFormat ))
1000
1040
@pytest .mark .parametrize ("dtype" , [torch .float32 , torch .int64 ])
1001
1041
@pytest .mark .parametrize ("device" , cpu_and_cuda ())
1002
- def test_kernel_bounding_box (self , format , angle , translate , scale , shear , center , dtype , device ):
1003
- bounding_box = self ._make_input (datapoints .BoundingBox , dtype = dtype , device = device , format = format )
1004
- check_kernel (
1042
+ def test_kernel_bounding_box_angle (self , angle , format , dtype , device ):
1043
+ bounding_box = self ._make_input (datapoints .BoundingBox , format = format , dtype = dtype , device = device )
1044
+ self . _check_kernel (
1005
1045
F .affine_bounding_box ,
1006
- bounding_box ,
1046
+ self . _make_input ( datapoints . BoundingBox , format = format , dtype = dtype , device = device ) ,
1007
1047
format = format ,
1008
1048
spatial_size = bounding_box .spatial_size ,
1009
1049
angle = angle ,
1050
+ )
1051
+
1052
+ @pytest .mark .parametrize ("translate" , _EXHAUSTIVE_TYPE_AFFINE_KWARGS ["translate" ])
1053
+ @pytest .mark .parametrize ("format" , list (datapoints .BoundingBoxFormat ))
1054
+ @pytest .mark .parametrize ("dtype" , [torch .float32 , torch .int64 ])
1055
+ @pytest .mark .parametrize ("device" , cpu_and_cuda ())
1056
+ def test_kernel_bounding_box_translate (self , translate , format , dtype , device ):
1057
+ bounding_box = self ._make_input (datapoints .BoundingBox , format = format , dtype = dtype , device = device )
1058
+ self ._check_kernel (
1059
+ F .affine_bounding_box ,
1060
+ self ._make_input (datapoints .BoundingBox , format = format , dtype = dtype , device = device ),
1061
+ format = format ,
1062
+ spatial_size = bounding_box .spatial_size ,
1010
1063
translate = translate ,
1011
- scale = scale ,
1064
+ )
1065
+
1066
+ @pytest .mark .parametrize ("shear" , _EXHAUSTIVE_TYPE_AFFINE_KWARGS ["shear" ])
1067
+ @pytest .mark .parametrize ("format" , list (datapoints .BoundingBoxFormat ))
1068
+ @pytest .mark .parametrize ("dtype" , [torch .float32 , torch .int64 ])
1069
+ @pytest .mark .parametrize ("device" , cpu_and_cuda ())
1070
+ def test_kernel_bounding_box_shear (self , shear , format , dtype , device ):
1071
+ bounding_box = self ._make_input (datapoints .BoundingBox , format = format , dtype = dtype , device = device )
1072
+ self ._check_kernel (
1073
+ F .affine_bounding_box ,
1074
+ self ._make_input (datapoints .BoundingBox , format = format , dtype = dtype , device = device ),
1075
+ format = format ,
1076
+ spatial_size = bounding_box .spatial_size ,
1012
1077
shear = shear ,
1013
- center = center ,
1014
1078
check_scripted_vs_eager = not isinstance (shear , (int , float )),
1015
1079
)
1016
1080
1081
+ @pytest .mark .parametrize ("center" , _EXHAUSTIVE_TYPE_AFFINE_KWARGS ["center" ])
1082
+ @pytest .mark .parametrize ("format" , list (datapoints .BoundingBoxFormat ))
1083
+ @pytest .mark .parametrize ("dtype" , [torch .float32 , torch .int64 ])
1084
+ @pytest .mark .parametrize ("device" , cpu_and_cuda ())
1085
+ def test_kernel_bounding_box_center (self , center , format , dtype , device ):
1086
+ bounding_box = self ._make_input (datapoints .BoundingBox , format = format , dtype = dtype , device = device )
1087
+ self ._check_kernel (
1088
+ F .affine_bounding_box ,
1089
+ self ._make_input (datapoints .BoundingBox , format = format , dtype = dtype , device = device ),
1090
+ format = format ,
1091
+ spatial_size = bounding_box .spatial_size ,
1092
+ center = center ,
1093
+ )
1094
+
1017
1095
@pytest .mark .parametrize ("mask_type" , ["segmentation" , "detection" ])
1018
1096
def test_kernel_mask (self , mask_type ):
1019
1097
check_kernel (
0 commit comments