Skip to content

Commit 7cf0f4c

Browse files
authored
make transforms v2 JIT scriptable (#7135)
1 parent 170160a commit 7cf0f4c

File tree

8 files changed

+206
-17
lines changed

8 files changed

+206
-17
lines changed

test/test_prototype_transforms_consistency.py

+58-13
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,15 @@
3434
DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=["RGB"], extra_dims=[(4,)])
3535

3636

37+
class NotScriptableArgsKwargs(ArgsKwargs):
    """Marker subclass of ``ArgsKwargs`` for parameters that make a transform non-scriptable.

    Parameters wrapped in this class still work in eager mode and are therefore tested there; the JIT tests
    filter them out with an ``isinstance`` check and skip them.
    """
44+
45+
3746
class ConsistencyConfig:
3847
def __init__(
3948
self,
@@ -73,7 +82,7 @@ def __init__(
7382
prototype_transforms.Resize,
7483
legacy_transforms.Resize,
7584
[
76-
ArgsKwargs(32),
85+
NotScriptableArgsKwargs(32),
7786
ArgsKwargs([32]),
7887
ArgsKwargs((32, 29)),
7988
ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST),
@@ -84,8 +93,10 @@ def __init__(
8493
# ArgsKwargs((30, 27), interpolation=0),
8594
# ArgsKwargs((35, 29), interpolation=2),
8695
# ArgsKwargs((34, 25), interpolation=3),
87-
ArgsKwargs(31, max_size=32),
88-
ArgsKwargs(30, max_size=100),
96+
NotScriptableArgsKwargs(31, max_size=32),
97+
ArgsKwargs([31], max_size=32),
98+
NotScriptableArgsKwargs(30, max_size=100),
99+
# Scriptable (sequence-sized) counterpart of the int-sized `30, max_size=100` case; the previous entry here was
# an accidental duplicate of `ArgsKwargs([31], max_size=32)`, leaving this configuration untested under JIT.
ArgsKwargs([30], max_size=100),
89100
ArgsKwargs((29, 32), antialias=False),
90101
ArgsKwargs((28, 31), antialias=True),
91102
],
@@ -121,14 +132,15 @@ def __init__(
121132
prototype_transforms.Pad,
122133
legacy_transforms.Pad,
123134
[
124-
ArgsKwargs(3),
135+
NotScriptableArgsKwargs(3),
125136
ArgsKwargs([3]),
126137
ArgsKwargs([2, 3]),
127138
ArgsKwargs([3, 2, 1, 4]),
128-
ArgsKwargs(5, fill=1, padding_mode="constant"),
129-
ArgsKwargs(5, padding_mode="edge"),
130-
ArgsKwargs(5, padding_mode="reflect"),
131-
ArgsKwargs(5, padding_mode="symmetric"),
139+
NotScriptableArgsKwargs(5, fill=1, padding_mode="constant"),
140+
ArgsKwargs([5], fill=1, padding_mode="constant"),
141+
NotScriptableArgsKwargs(5, padding_mode="edge"),
142+
NotScriptableArgsKwargs(5, padding_mode="reflect"),
143+
NotScriptableArgsKwargs(5, padding_mode="symmetric"),
132144
],
133145
),
134146
ConsistencyConfig(
@@ -170,7 +182,7 @@ def __init__(
170182
ConsistencyConfig(
171183
prototype_transforms.ToPILImage,
172184
legacy_transforms.ToPILImage,
173-
[ArgsKwargs()],
185+
[NotScriptableArgsKwargs()],
174186
make_images_kwargs=dict(
175187
color_spaces=[
176188
"GRAY",
@@ -186,7 +198,7 @@ def __init__(
186198
prototype_transforms.Lambda,
187199
legacy_transforms.Lambda,
188200
[
189-
ArgsKwargs(lambda image: image / 2),
201+
NotScriptableArgsKwargs(lambda image: image / 2),
190202
],
191203
# Technically, this also supports PIL, but it is overkill to write a function here that supports tensor and PIL
192204
# images given that the transform does nothing but call it anyway.
@@ -380,14 +392,15 @@ def __init__(
380392
[
381393
ArgsKwargs(12),
382394
ArgsKwargs((15, 17)),
383-
ArgsKwargs(11, padding=1),
395+
NotScriptableArgsKwargs(11, padding=1),
396+
ArgsKwargs(11, padding=[1]),
384397
ArgsKwargs((8, 13), padding=(2, 3)),
385398
ArgsKwargs((14, 9), padding=(0, 2, 1, 0)),
386399
ArgsKwargs(36, pad_if_needed=True),
387400
ArgsKwargs((7, 8), fill=1),
388-
ArgsKwargs(5, fill=(1, 2, 3)),
401+
NotScriptableArgsKwargs(5, fill=(1, 2, 3)),
389402
ArgsKwargs(12),
390-
ArgsKwargs(15, padding=2, padding_mode="edge"),
403+
NotScriptableArgsKwargs(15, padding=2, padding_mode="edge"),
391404
ArgsKwargs(17, padding=(1, 0), padding_mode="reflect"),
392405
ArgsKwargs(8, padding=(3, 0, 0, 1), padding_mode="symmetric"),
393406
],
@@ -642,6 +655,38 @@ def test_call_consistency(config, args_kwargs):
642655
)
643656

644657

658+
@pytest.mark.parametrize(
    ("config", "args_kwargs"),
    [
        pytest.param(
            config, args_kwargs, id=f"{config.legacy_cls.__name__}-{idx:0{len(str(len(config.args_kwargs)))}d}"
        )
        for config in CONSISTENCY_CONFIGS
        for idx, args_kwargs in enumerate(config.args_kwargs)
        # Parameters marked as non-scriptable are only covered by the eager-mode tests.
        if not isinstance(args_kwargs, NotScriptableArgsKwargs)
    ],
)
def test_jit_consistency(config, args_kwargs):
    """Check that the scripted prototype transform matches the scripted legacy transform.

    Both transforms are built from the same ``args_kwargs``, compiled with ``torch.jit.script`` and applied to
    every image produced by ``make_images``. The RNG is re-seeded before each call so both transforms run from
    the same random state.
    """
    args, kwargs = args_kwargs

    # Instantiate both eager transforms from identical arguments, then compile each of them.
    eager_prototype = config.prototype_cls(*args, **kwargs)
    eager_legacy = config.legacy_cls(*args, **kwargs)
    scripted_legacy = torch.jit.script(eager_legacy)
    scripted_prototype = torch.jit.script(eager_prototype)

    for sample in make_images(**config.make_images_kwargs):
        # The scripted transforms are fed a plain ``torch.Tensor`` view of the generated image.
        plain_image = sample.as_subclass(torch.Tensor)

        results = []
        for scripted_transform in (scripted_legacy, scripted_prototype):
            torch.manual_seed(0)
            results.append(scripted_transform(plain_image))
        expected, actual = results

        assert_close(actual, expected, **config.closeness_kwargs)
689+
645690
class TestContainerTransforms:
646691
"""
647692
Since we are testing containers here, we also need some transforms to wrap. Thus, testing a container transform for

torchvision/prototype/transforms/_augment.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import PIL.Image
77
import torch
88
from torch.utils._pytree import tree_flatten, tree_unflatten
9-
9+
from torchvision import transforms as _transforms
1010
from torchvision.ops import masks_to_boxes
1111
from torchvision.prototype import datapoints
1212
from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform
@@ -16,6 +16,14 @@
1616

1717

1818
class RandomErasing(_RandomApplyTransform):
19+
_v1_transform_cls = _transforms.RandomErasing
20+
21+
def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
    """Return the parent's extracted v1 parameters with ``value`` overridden.

    A ``self.value`` of ``None`` is passed on as the string ``"random"``; any other value is forwarded
    unchanged.
    """
    v1_params = super()._extract_params_for_v1_transform()
    erase_value = self.value
    # Build a new dict so the mapping returned by the parent is left untouched.
    return {**v1_params, "value": "random" if erase_value is None else erase_value}
26+
1927
_transformed_types = (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video)
2028

2129
def __init__(

torchvision/prototype/transforms/_auto_augment.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch
66

77
from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
8-
8+
from torchvision import transforms as _transforms
99
from torchvision.prototype import datapoints
1010
from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
1111
from torchvision.prototype.transforms.functional._meta import get_spatial_size
@@ -161,6 +161,8 @@ def _apply_image_or_video_transform(
161161

162162

163163
class AutoAugment(_AutoAugmentBase):
164+
_v1_transform_cls = _transforms.AutoAugment
165+
164166
_AUGMENTATION_SPACE = {
165167
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
166168
"ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
@@ -315,6 +317,7 @@ def forward(self, *inputs: Any) -> Any:
315317

316318

317319
class RandAugment(_AutoAugmentBase):
320+
_v1_transform_cls = _transforms.RandAugment
318321
_AUGMENTATION_SPACE = {
319322
"Identity": (lambda num_bins, height, width: None, False),
320323
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
@@ -375,6 +378,7 @@ def forward(self, *inputs: Any) -> Any:
375378

376379

377380
class TrivialAugmentWide(_AutoAugmentBase):
381+
_v1_transform_cls = _transforms.TrivialAugmentWide
378382
_AUGMENTATION_SPACE = {
379383
"Identity": (lambda num_bins, height, width: None, False),
380384
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
@@ -425,6 +429,8 @@ def forward(self, *inputs: Any) -> Any:
425429

426430

427431
class AugMix(_AutoAugmentBase):
432+
_v1_transform_cls = _transforms.AugMix
433+
428434
_PARTIAL_AUGMENTATION_SPACE = {
429435
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
430436
"ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),

torchvision/prototype/transforms/_color.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import PIL.Image
55
import torch
6-
6+
from torchvision import transforms as _transforms
77
from torchvision.prototype import datapoints
88
from torchvision.prototype.transforms import functional as F, Transform
99

@@ -12,6 +12,8 @@
1212

1313

1414
class Grayscale(Transform):
15+
_v1_transform_cls = _transforms.Grayscale
16+
1517
_transformed_types = (
1618
datapoints.Image,
1719
PIL.Image.Image,
@@ -28,6 +30,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
2830

2931

3032
class RandomGrayscale(_RandomApplyTransform):
33+
_v1_transform_cls = _transforms.RandomGrayscale
34+
3135
_transformed_types = (
3236
datapoints.Image,
3337
PIL.Image.Image,
@@ -47,6 +51,11 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
4751

4852

4953
class ColorJitter(Transform):
54+
_v1_transform_cls = _transforms.ColorJitter
55+
56+
def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
    """Return the parent's extracted v1 parameters with every falsy value (e.g. ``None``) replaced by ``0``."""
    extracted = super()._extract_params_for_v1_transform()
    sanitized: Dict[str, Any] = {}
    for attr, value in extracted.items():
        # ``or`` keeps truthy values as they are and substitutes 0 for anything falsy.
        sanitized[attr] = value or 0
    return sanitized
58+
5059
def __init__(
5160
self,
5261
brightness: Optional[Union[float, Sequence[float]]] = None,
@@ -194,16 +203,22 @@ def _transform(
194203

195204

196205
class RandomEqualize(_RandomApplyTransform):
    """Transform whose ``_transform`` step equalizes its input via ``F.equalize``."""

    # Legacy ``torchvision.transforms`` class associated with this transform.
    _v1_transform_cls = _transforms.RandomEqualize

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # ``params`` is part of the method signature but is not needed for this operation.
        equalized = F.equalize(inpt)
        return equalized
199210

200211

201212
class RandomInvert(_RandomApplyTransform):
    """Transform whose ``_transform`` step inverts its input via ``F.invert``."""

    # Legacy ``torchvision.transforms`` class associated with this transform.
    _v1_transform_cls = _transforms.RandomInvert

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # ``params`` is part of the method signature but is not needed for this operation.
        inverted = F.invert(inpt)
        return inverted
204217

205218

206219
class RandomPosterize(_RandomApplyTransform):
220+
_v1_transform_cls = _transforms.RandomPosterize
221+
207222
def __init__(self, bits: int, p: float = 0.5) -> None:
208223
super().__init__(p=p)
209224
self.bits = bits
@@ -213,6 +228,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
213228

214229

215230
class RandomSolarize(_RandomApplyTransform):
231+
_v1_transform_cls = _transforms.RandomSolarize
232+
216233
def __init__(self, threshold: float, p: float = 0.5) -> None:
217234
super().__init__(p=p)
218235
self.threshold = threshold
@@ -222,11 +239,15 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
222239

223240

224241
class RandomAutocontrast(_RandomApplyTransform):
    """Transform whose ``_transform`` step applies ``F.autocontrast`` to its input."""

    # Legacy ``torchvision.transforms`` class associated with this transform.
    _v1_transform_cls = _transforms.RandomAutocontrast

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # ``params`` is part of the method signature but is not needed for this operation.
        contrasted = F.autocontrast(inpt)
        return contrasted
227246

228247

229248
class RandomAdjustSharpness(_RandomApplyTransform):
249+
_v1_transform_cls = _transforms.RandomAdjustSharpness
250+
230251
def __init__(self, sharpness_factor: float, p: float = 0.5) -> None:
231252
super().__init__(p=p)
232253
self.sharpness_factor = sharpness_factor

0 commit comments

Comments
 (0)