34
34
DEFAULT_MAKE_IMAGES_KWARGS = dict (color_spaces = ["RGB" ], extra_dims = [(4 ,)])
35
35
36
36
37
+ class NotScriptableArgsKwargs (ArgsKwargs ):
38
+ """
39
+ This class is used to mark parameters that render the transform non-scriptable. They still work in eager mode and
40
+ thus will be tested there, but will be skipped by the JIT tests.
41
+ """
42
+
43
+ pass
44
+
45
+
37
46
class ConsistencyConfig :
38
47
def __init__ (
39
48
self ,
@@ -73,7 +82,7 @@ def __init__(
73
82
prototype_transforms .Resize ,
74
83
legacy_transforms .Resize ,
75
84
[
76
- ArgsKwargs (32 ),
85
+ NotScriptableArgsKwargs (32 ),
77
86
ArgsKwargs ([32 ]),
78
87
ArgsKwargs ((32 , 29 )),
79
88
ArgsKwargs ((31 , 28 ), interpolation = prototype_transforms .InterpolationMode .NEAREST ),
@@ -84,8 +93,10 @@ def __init__(
84
93
# ArgsKwargs((30, 27), interpolation=0),
85
94
# ArgsKwargs((35, 29), interpolation=2),
86
95
# ArgsKwargs((34, 25), interpolation=3),
87
- ArgsKwargs (31 , max_size = 32 ),
88
- ArgsKwargs (30 , max_size = 100 ),
96
+ NotScriptableArgsKwargs (31 , max_size = 32 ),
97
+ ArgsKwargs ([31 ], max_size = 32 ),
98
+ NotScriptableArgsKwargs (30 , max_size = 100 ),
99
+ ArgsKwargs ([31 ], max_size = 32 ),
89
100
ArgsKwargs ((29 , 32 ), antialias = False ),
90
101
ArgsKwargs ((28 , 31 ), antialias = True ),
91
102
],
@@ -121,14 +132,15 @@ def __init__(
121
132
prototype_transforms .Pad ,
122
133
legacy_transforms .Pad ,
123
134
[
124
- ArgsKwargs (3 ),
135
+ NotScriptableArgsKwargs (3 ),
125
136
ArgsKwargs ([3 ]),
126
137
ArgsKwargs ([2 , 3 ]),
127
138
ArgsKwargs ([3 , 2 , 1 , 4 ]),
128
- ArgsKwargs (5 , fill = 1 , padding_mode = "constant" ),
129
- ArgsKwargs (5 , padding_mode = "edge" ),
130
- ArgsKwargs (5 , padding_mode = "reflect" ),
131
- ArgsKwargs (5 , padding_mode = "symmetric" ),
139
+ NotScriptableArgsKwargs (5 , fill = 1 , padding_mode = "constant" ),
140
+ ArgsKwargs ([5 ], fill = 1 , padding_mode = "constant" ),
141
+ NotScriptableArgsKwargs (5 , padding_mode = "edge" ),
142
+ NotScriptableArgsKwargs (5 , padding_mode = "reflect" ),
143
+ NotScriptableArgsKwargs (5 , padding_mode = "symmetric" ),
132
144
],
133
145
),
134
146
ConsistencyConfig (
@@ -170,7 +182,7 @@ def __init__(
170
182
ConsistencyConfig (
171
183
prototype_transforms .ToPILImage ,
172
184
legacy_transforms .ToPILImage ,
173
- [ArgsKwargs ()],
185
+ [NotScriptableArgsKwargs ()],
174
186
make_images_kwargs = dict (
175
187
color_spaces = [
176
188
"GRAY" ,
@@ -186,7 +198,7 @@ def __init__(
186
198
prototype_transforms .Lambda ,
187
199
legacy_transforms .Lambda ,
188
200
[
189
- ArgsKwargs (lambda image : image / 2 ),
201
+ NotScriptableArgsKwargs (lambda image : image / 2 ),
190
202
],
191
203
# Technically, this also supports PIL, but it is overkill to write a function here that supports tensor and PIL
192
204
# images given that the transform does nothing but call it anyway.
@@ -380,14 +392,15 @@ def __init__(
380
392
[
381
393
ArgsKwargs (12 ),
382
394
ArgsKwargs ((15 , 17 )),
383
- ArgsKwargs (11 , padding = 1 ),
395
+ NotScriptableArgsKwargs (11 , padding = 1 ),
396
+ ArgsKwargs (11 , padding = [1 ]),
384
397
ArgsKwargs ((8 , 13 ), padding = (2 , 3 )),
385
398
ArgsKwargs ((14 , 9 ), padding = (0 , 2 , 1 , 0 )),
386
399
ArgsKwargs (36 , pad_if_needed = True ),
387
400
ArgsKwargs ((7 , 8 ), fill = 1 ),
388
- ArgsKwargs (5 , fill = (1 , 2 , 3 )),
401
+ NotScriptableArgsKwargs (5 , fill = (1 , 2 , 3 )),
389
402
ArgsKwargs (12 ),
390
- ArgsKwargs (15 , padding = 2 , padding_mode = "edge" ),
403
+ NotScriptableArgsKwargs (15 , padding = 2 , padding_mode = "edge" ),
391
404
ArgsKwargs (17 , padding = (1 , 0 ), padding_mode = "reflect" ),
392
405
ArgsKwargs (8 , padding = (3 , 0 , 0 , 1 ), padding_mode = "symmetric" ),
393
406
],
@@ -642,6 +655,38 @@ def test_call_consistency(config, args_kwargs):
642
655
)
643
656
644
657
658
+ @pytest .mark .parametrize (
659
+ ("config" , "args_kwargs" ),
660
+ [
661
+ pytest .param (
662
+ config , args_kwargs , id = f"{ config .legacy_cls .__name__ } -{ idx :0{len (str (len (config .args_kwargs )))}d} "
663
+ )
664
+ for config in CONSISTENCY_CONFIGS
665
+ for idx , args_kwargs in enumerate (config .args_kwargs )
666
+ if not isinstance (args_kwargs , NotScriptableArgsKwargs )
667
+ ],
668
+ )
669
+ def test_jit_consistency (config , args_kwargs ):
670
+ args , kwargs = args_kwargs
671
+
672
+ prototype_transform_eager = config .prototype_cls (* args , ** kwargs )
673
+ legacy_transform_eager = config .legacy_cls (* args , ** kwargs )
674
+
675
+ legacy_transform_scripted = torch .jit .script (legacy_transform_eager )
676
+ prototype_transform_scripted = torch .jit .script (prototype_transform_eager )
677
+
678
+ for image in make_images (** config .make_images_kwargs ):
679
+ image = image .as_subclass (torch .Tensor )
680
+
681
+ torch .manual_seed (0 )
682
+ output_legacy_scripted = legacy_transform_scripted (image )
683
+
684
+ torch .manual_seed (0 )
685
+ output_prototype_scripted = prototype_transform_scripted (image )
686
+
687
+ assert_close (output_prototype_scripted , output_legacy_scripted , ** config .closeness_kwargs )
688
+
689
+
645
690
class TestContainerTransforms :
646
691
"""
647
692
Since we are testing containers here, we also need some transforms to wrap. Thus, testing a container transform for
0 commit comments