Skip to content

Commit 08cd318

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] Support integer values for interpolation in the prototype transforms (#7248)
Reviewed By: vmoens Differential Revision: D44416580 fbshipit-source-id: 41b5fa458ba3a54b3f1e4787a289d3408bcd01e4
1 parent 31d5d9f commit 08cd318

12 files changed

+153
-96
lines changed

test/test_prototype_transforms.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1534,7 +1534,7 @@ def test__get_params(self, mocker):
15341534
assert int(spatial_size[1] * r_min) <= width <= int(spatial_size[1] * r_max)
15351535

15361536
def test__transform(self, mocker):
1537-
interpolation_sentinel = mocker.MagicMock()
1537+
interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
15381538
antialias_sentinel = mocker.MagicMock()
15391539

15401540
transform = transforms.ScaleJitter(
@@ -1581,7 +1581,7 @@ def test__get_params(self, min_size, max_size, mocker):
15811581
assert shorter in min_size
15821582

15831583
def test__transform(self, mocker):
1584-
interpolation_sentinel = mocker.MagicMock()
1584+
interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
15851585
antialias_sentinel = mocker.MagicMock()
15861586

15871587
transform = transforms.RandomShortestSize(
@@ -1945,7 +1945,7 @@ def test__get_params(self):
19451945
assert min_size <= size < max_size
19461946

19471947
def test__transform(self, mocker):
1948-
interpolation_sentinel = mocker.MagicMock()
1948+
interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
19491949
antialias_sentinel = mocker.MagicMock()
19501950

19511951
transform = transforms.RandomResize(

test/test_prototype_transforms_consistency.py

+36-10
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ def __init__(
8888
ArgsKwargs((32, 29)),
8989
ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST),
9090
ArgsKwargs((33, 26), interpolation=prototype_transforms.InterpolationMode.BICUBIC),
91+
ArgsKwargs((30, 27), interpolation=PIL.Image.NEAREST),
92+
ArgsKwargs((35, 29), interpolation=PIL.Image.BILINEAR),
93+
ArgsKwargs((34, 25), interpolation=PIL.Image.BICUBIC),
9194
NotScriptableArgsKwargs(31, max_size=32),
9295
ArgsKwargs([31], max_size=32),
9396
NotScriptableArgsKwargs(30, max_size=100),
@@ -305,6 +308,8 @@ def __init__(
305308
ArgsKwargs(25, ratio=(0.5, 1.5)),
306309
ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST),
307310
ArgsKwargs((33, 26), interpolation=prototype_transforms.InterpolationMode.BICUBIC),
311+
ArgsKwargs((31, 28), interpolation=PIL.Image.NEAREST),
312+
ArgsKwargs((33, 26), interpolation=PIL.Image.BICUBIC),
308313
ArgsKwargs((29, 32), antialias=False),
309314
ArgsKwargs((28, 31), antialias=True),
310315
],
@@ -352,6 +357,8 @@ def __init__(
352357
ArgsKwargs(sigma=(2.5, 3.9)),
353358
ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.NEAREST),
354359
ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.BICUBIC),
360+
ArgsKwargs(interpolation=PIL.Image.NEAREST),
361+
ArgsKwargs(interpolation=PIL.Image.BICUBIC),
355362
ArgsKwargs(fill=1),
356363
],
357364
# ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image
@@ -386,6 +393,7 @@ def __init__(
386393
ArgsKwargs(degrees=0.0, shear=(4, 5, 4, 13)),
387394
ArgsKwargs(degrees=(-20.0, 10.0), translate=(0.4, 0.6), scale=(0.3, 0.8), shear=(4, 5, 4, 13)),
388395
ArgsKwargs(degrees=30.0, interpolation=prototype_transforms.InterpolationMode.NEAREST),
396+
ArgsKwargs(degrees=30.0, interpolation=PIL.Image.NEAREST),
389397
ArgsKwargs(degrees=30.0, fill=1),
390398
ArgsKwargs(degrees=30.0, fill=(2, 3, 4)),
391399
ArgsKwargs(degrees=30.0, center=(0, 0)),
@@ -420,6 +428,7 @@ def __init__(
420428
ArgsKwargs(p=1),
421429
ArgsKwargs(p=1, distortion_scale=0.3),
422430
ArgsKwargs(p=1, distortion_scale=0.2, interpolation=prototype_transforms.InterpolationMode.NEAREST),
431+
ArgsKwargs(p=1, distortion_scale=0.2, interpolation=PIL.Image.NEAREST),
423432
ArgsKwargs(p=1, distortion_scale=0.1, fill=1),
424433
ArgsKwargs(p=1, distortion_scale=0.4, fill=(1, 2, 3)),
425434
],
@@ -432,6 +441,7 @@ def __init__(
432441
ArgsKwargs(degrees=30.0),
433442
ArgsKwargs(degrees=(-20.0, 10.0)),
434443
ArgsKwargs(degrees=30.0, interpolation=prototype_transforms.InterpolationMode.BILINEAR),
444+
ArgsKwargs(degrees=30.0, interpolation=PIL.Image.BILINEAR),
435445
ArgsKwargs(degrees=30.0, expand=True),
436446
ArgsKwargs(degrees=30.0, center=(0, 0)),
437447
ArgsKwargs(degrees=30.0, fill=1),
@@ -851,7 +861,11 @@ class TestAATransforms:
851861
)
852862
@pytest.mark.parametrize(
853863
"interpolation",
854-
[prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR],
864+
[
865+
prototype_transforms.InterpolationMode.NEAREST,
866+
prototype_transforms.InterpolationMode.BILINEAR,
867+
PIL.Image.NEAREST,
868+
],
855869
)
856870
def test_randaug(self, inpt, interpolation, mocker):
857871
t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1)
@@ -889,7 +903,11 @@ def test_randaug(self, inpt, interpolation, mocker):
889903
)
890904
@pytest.mark.parametrize(
891905
"interpolation",
892-
[prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR],
906+
[
907+
prototype_transforms.InterpolationMode.NEAREST,
908+
prototype_transforms.InterpolationMode.BILINEAR,
909+
PIL.Image.NEAREST,
910+
],
893911
)
894912
def test_trivial_aug(self, inpt, interpolation, mocker):
895913
t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation)
@@ -937,7 +955,11 @@ def test_trivial_aug(self, inpt, interpolation, mocker):
937955
)
938956
@pytest.mark.parametrize(
939957
"interpolation",
940-
[prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR],
958+
[
959+
prototype_transforms.InterpolationMode.NEAREST,
960+
prototype_transforms.InterpolationMode.BILINEAR,
961+
PIL.Image.NEAREST,
962+
],
941963
)
942964
def test_augmix(self, inpt, interpolation, mocker):
943965
t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
@@ -986,7 +1008,11 @@ def test_augmix(self, inpt, interpolation, mocker):
9861008
)
9871009
@pytest.mark.parametrize(
9881010
"interpolation",
989-
[prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR],
1011+
[
1012+
prototype_transforms.InterpolationMode.NEAREST,
1013+
prototype_transforms.InterpolationMode.BILINEAR,
1014+
PIL.Image.NEAREST,
1015+
],
9901016
)
9911017
def test_aa(self, inpt, interpolation):
9921018
aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet")
@@ -1264,13 +1290,13 @@ def test_random_resize_eval(self, mocker):
12641290
(legacy_F.convert_image_dtype, {}),
12651291
(legacy_F.to_pil_image, {}),
12661292
(legacy_F.normalize, {}),
1267-
(legacy_F.resize, {}),
1293+
(legacy_F.resize, {"interpolation"}),
12681294
(legacy_F.pad, {"padding", "fill"}),
12691295
(legacy_F.crop, {}),
12701296
(legacy_F.center_crop, {}),
1271-
(legacy_F.resized_crop, {}),
1297+
(legacy_F.resized_crop, {"interpolation"}),
12721298
(legacy_F.hflip, {}),
1273-
(legacy_F.perspective, {"startpoints", "endpoints", "fill"}),
1299+
(legacy_F.perspective, {"startpoints", "endpoints", "fill", "interpolation"}),
12741300
(legacy_F.vflip, {}),
12751301
(legacy_F.five_crop, {}),
12761302
(legacy_F.ten_crop, {}),
@@ -1279,8 +1305,8 @@ def test_random_resize_eval(self, mocker):
12791305
(legacy_F.adjust_saturation, {}),
12801306
(legacy_F.adjust_hue, {}),
12811307
(legacy_F.adjust_gamma, {}),
1282-
(legacy_F.rotate, {"center", "fill"}),
1283-
(legacy_F.affine, {"angle", "translate", "center", "fill"}),
1308+
(legacy_F.rotate, {"center", "fill", "interpolation"}),
1309+
(legacy_F.affine, {"angle", "translate", "center", "fill", "interpolation"}),
12841310
(legacy_F.to_grayscale, {}),
12851311
(legacy_F.rgb_to_grayscale, {}),
12861312
(legacy_F.to_tensor, {}),
@@ -1292,7 +1318,7 @@ def test_random_resize_eval(self, mocker):
12921318
(legacy_F.adjust_sharpness, {}),
12931319
(legacy_F.autocontrast, {}),
12941320
(legacy_F.equalize, {}),
1295-
(legacy_F.elastic_transform, {"fill"}),
1321+
(legacy_F.elastic_transform, {"fill", "interpolation"}),
12961322
],
12971323
)
12981324
def test_dispatcher_signature_consistency(legacy_dispatcher, name_only_params):

torchvision/prototype/datapoints/_bounding_box.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def vertical_flip(self) -> BoundingBox:
7676
def resize( # type: ignore[override]
7777
self,
7878
size: List[int],
79-
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
79+
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
8080
max_size: Optional[int] = None,
8181
antialias: Optional[Union[str, bool]] = "warn",
8282
) -> BoundingBox:
@@ -107,7 +107,7 @@ def resized_crop(
107107
height: int,
108108
width: int,
109109
size: List[int],
110-
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
110+
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
111111
antialias: Optional[Union[str, bool]] = "warn",
112112
) -> BoundingBox:
113113
output, spatial_size = self._F.resized_crop_bounding_box(
@@ -133,7 +133,7 @@ def pad(
133133
def rotate(
134134
self,
135135
angle: float,
136-
interpolation: InterpolationMode = InterpolationMode.NEAREST,
136+
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
137137
expand: bool = False,
138138
center: Optional[List[float]] = None,
139139
fill: FillTypeJIT = None,
@@ -154,7 +154,7 @@ def affine(
154154
translate: List[float],
155155
scale: float,
156156
shear: List[float],
157-
interpolation: InterpolationMode = InterpolationMode.NEAREST,
157+
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
158158
fill: FillTypeJIT = None,
159159
center: Optional[List[float]] = None,
160160
) -> BoundingBox:
@@ -174,7 +174,7 @@ def perspective(
174174
self,
175175
startpoints: Optional[List[List[int]]],
176176
endpoints: Optional[List[List[int]]],
177-
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
177+
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
178178
fill: FillTypeJIT = None,
179179
coefficients: Optional[List[float]] = None,
180180
) -> BoundingBox:
@@ -191,7 +191,7 @@ def perspective(
191191
def elastic(
192192
self,
193193
displacement: torch.Tensor,
194-
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
194+
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
195195
fill: FillTypeJIT = None,
196196
) -> BoundingBox:
197197
output = self._F.elastic_bounding_box(

torchvision/prototype/datapoints/_datapoint.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def vertical_flip(self) -> Datapoint:
143143
def resize( # type: ignore[override]
144144
self,
145145
size: List[int],
146-
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
146+
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
147147
max_size: Optional[int] = None,
148148
antialias: Optional[Union[str, bool]] = "warn",
149149
) -> Datapoint:
@@ -162,7 +162,7 @@ def resized_crop(
162162
height: int,
163163
width: int,
164164
size: List[int],
165-
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
165+
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
166166
antialias: Optional[Union[str, bool]] = "warn",
167167
) -> Datapoint:
168168
return self
@@ -178,7 +178,7 @@ def pad(
178178
def rotate(
179179
self,
180180
angle: float,
181-
interpolation: InterpolationMode = InterpolationMode.NEAREST,
181+
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
182182
expand: bool = False,
183183
center: Optional[List[float]] = None,
184184
fill: FillTypeJIT = None,
@@ -191,7 +191,7 @@ def affine(
191191
translate: List[float],
192192
scale: float,
193193
shear: List[float],
194-
interpolation: InterpolationMode = InterpolationMode.NEAREST,
194+
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
195195
fill: FillTypeJIT = None,
196196
center: Optional[List[float]] = None,
197197
) -> Datapoint:
@@ -201,7 +201,7 @@ def perspective(
201201
self,
202202
startpoints: Optional[List[List[int]]],
203203
endpoints: Optional[List[List[int]]],
204-
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
204+
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
205205
fill: FillTypeJIT = None,
206206
coefficients: Optional[List[float]] = None,
207207
) -> Datapoint:
@@ -210,7 +210,7 @@ def perspective(
210210
def elastic(
211211
self,
212212
displacement: torch.Tensor,
213-
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
213+
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
214214
fill: FillTypeJIT = None,
215215
) -> Datapoint:
216216
return self

torchvision/prototype/datapoints/_image.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def vertical_flip(self) -> Image:
6262
def resize( # type: ignore[override]
6363
self,
6464
size: List[int],
65-
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
65+
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
6666
max_size: Optional[int] = None,
6767
antialias: Optional[Union[str, bool]] = "warn",
6868
) -> Image:
@@ -86,7 +86,7 @@ def resized_crop(
8686
height: int,
8787
width: int,
8888
size: List[int],
89-
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
89+
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
9090
antialias: Optional[Union[str, bool]] = "warn",
9191
) -> Image:
9292
output = self._F.resized_crop_image_tensor(
@@ -113,7 +113,7 @@ def pad(
113113
def rotate(
114114
self,
115115
angle: float,
116-
interpolation: InterpolationMode = InterpolationMode.NEAREST,
116+
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
117117
expand: bool = False,
118118
center: Optional[List[float]] = None,
119119
fill: FillTypeJIT = None,
@@ -129,7 +129,7 @@ def affine(
129129
translate: List[float],
130130
scale: float,
131131
shear: List[float],
132-
interpolation: InterpolationMode = InterpolationMode.NEAREST,
132+
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
133133
fill: FillTypeJIT = None,
134134
center: Optional[List[float]] = None,
135135
) -> Image:
@@ -149,7 +149,7 @@ def perspective(
149149
self,
150150
startpoints: Optional[List[List[int]]],
151151
endpoints: Optional[List[List[int]]],
152-
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
152+
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
153153
fill: FillTypeJIT = None,
154154
coefficients: Optional[List[float]] = None,
155155
) -> Image:
@@ -166,7 +166,7 @@ def perspective(
166166
def elastic(
167167
self,
168168
displacement: torch.Tensor,
169-
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
169+
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
170170
fill: FillTypeJIT = None,
171171
) -> Image:
172172
output = self._F.elastic_image_tensor(

torchvision/prototype/datapoints/_mask.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def vertical_flip(self) -> Mask:
5353
def resize( # type: ignore[override]
5454
self,
5555
size: List[int],
56-
interpolation: InterpolationMode = InterpolationMode.NEAREST,
56+
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
5757
max_size: Optional[int] = None,
5858
antialias: Optional[Union[str, bool]] = "warn",
5959
) -> Mask:
@@ -75,7 +75,7 @@ def resized_crop(
7575
height: int,
7676
width: int,
7777
size: List[int],
78-
interpolation: InterpolationMode = InterpolationMode.NEAREST,
78+
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
7979
antialias: Optional[Union[str, bool]] = "warn",
8080
) -> Mask:
8181
output = self._F.resized_crop_mask(self.as_subclass(torch.Tensor), top, left, height, width, size=size)
@@ -93,7 +93,7 @@ def pad(
9393
def rotate(
9494
self,
9595
angle: float,
96-
interpolation: InterpolationMode = InterpolationMode.NEAREST,
96+
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
9797
expand: bool = False,
9898
center: Optional[List[float]] = None,
9999
fill: FillTypeJIT = None,
@@ -107,7 +107,7 @@ def affine(
107107
translate: List[float],
108108
scale: float,
109109
shear: List[float],
110-
interpolation: InterpolationMode = InterpolationMode.NEAREST,
110+
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
111111
fill: FillTypeJIT = None,
112112
center: Optional[List[float]] = None,
113113
) -> Mask:
@@ -126,7 +126,7 @@ def perspective(
126126
self,
127127
startpoints: Optional[List[List[int]]],
128128
endpoints: Optional[List[List[int]]],
129-
interpolation: InterpolationMode = InterpolationMode.NEAREST,
129+
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
130130
fill: FillTypeJIT = None,
131131
coefficients: Optional[List[float]] = None,
132132
) -> Mask:
@@ -138,7 +138,7 @@ def perspective(
138138
def elastic(
139139
self,
140140
displacement: torch.Tensor,
141-
interpolation: InterpolationMode = InterpolationMode.NEAREST,
141+
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
142142
fill: FillTypeJIT = None,
143143
) -> Mask:
144144
output = self._F.elastic_mask(self.as_subclass(torch.Tensor), displacement, fill=fill)

0 commit comments

Comments
 (0)