Skip to content

Commit 33a47cb

Browse files
datumboxNicolasHugpmeier
authored andcommitted
[fbsync] Update transforms for PIL deprecation (#5898)
Summary: * Update transforms for PIL deprecation * Changes agreed at #5898 * black, sort constants, version check * Format tests * Square brackets * Update torchvision/transforms/_pil_constants.py Reviewed By: YosuaMichael Differential Revision: D36281595 fbshipit-source-id: 677b0cb42dd23589197b757291d35a4b5f603c58 Co-authored-by: Nicolas Hug <[email protected]> Co-authored-by: Philip Meier <[email protected]> Co-authored-by: Nicolas Hug <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 309483c commit 33a47cb

File tree

7 files changed

+53
-25
lines changed

7 files changed

+53
-25
lines changed

test/test_onnx.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -412,12 +412,13 @@ def forward(self_module, images, features):
412412
def get_image(self, rel_path: str, size: Tuple[int, int]) -> torch.Tensor:
413413
import os
414414

415+
import torchvision.transforms._pil_constants as _pil_constants
415416
from PIL import Image
416417
from torchvision.transforms import functional as F
417418

418419
data_dir = os.path.join(os.path.dirname(__file__), "assets")
419420
path = os.path.join(data_dir, *rel_path.split("/"))
420-
image = Image.open(path).convert("RGB").resize(size, Image.BILINEAR)
421+
image = Image.open(path).convert("RGB").resize(size, _pil_constants.BILINEAR)
421422

422423
return F.convert_image_dtype(F.pil_to_tensor(image))
423424

test/test_transforms.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pytest
99
import torch
1010
import torchvision.transforms as transforms
11+
import torchvision.transforms._pil_constants as _pil_constants
1112
import torchvision.transforms.functional as F
1213
import torchvision.transforms.functional_tensor as F_t
1314
from PIL import Image
@@ -173,7 +174,7 @@ def test_accimage_pil_to_tensor(self):
173174
def test_accimage_resize(self):
174175
trans = transforms.Compose(
175176
[
176-
transforms.Resize(256, interpolation=Image.LINEAR),
177+
transforms.Resize(256, interpolation=_pil_constants.LINEAR),
177178
transforms.PILToTensor(),
178179
transforms.ConvertImageDtype(dtype=torch.float),
179180
]

test/test_transforms_tensor.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55
import pytest
66
import torch
7+
import torchvision.transforms._pil_constants as _pil_constants
78
from common_utils import (
89
get_tmp_dir,
910
int_dtypes,
@@ -15,7 +16,6 @@
1516
cpu_and_gpu,
1617
assert_equal,
1718
)
18-
from PIL import Image
1919
from torchvision import transforms as T
2020
from torchvision.transforms import InterpolationMode
2121
from torchvision.transforms import functional as F
@@ -771,13 +771,13 @@ def shear(pil_img, level, mode, resample):
771771
matrix = (1, level, 0, 0, 1, 0)
772772
elif mode == "Y":
773773
matrix = (1, 0, 0, level, 1, 0)
774-
return pil_img.transform((image_size, image_size), Image.AFFINE, matrix, resample=resample)
774+
return pil_img.transform((image_size, image_size), _pil_constants.AFFINE, matrix, resample=resample)
775775

776776
t_img, pil_img = _create_data(image_size, image_size)
777777

778778
resample_pil = {
779-
F.InterpolationMode.NEAREST: Image.NEAREST,
780-
F.InterpolationMode.BILINEAR: Image.BILINEAR,
779+
F.InterpolationMode.NEAREST: _pil_constants.NEAREST,
780+
F.InterpolationMode.BILINEAR: _pil_constants.BILINEAR,
781781
}[interpolation]
782782

783783
level = 0.3
+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import PIL
2+
from PIL import Image
3+
4+
# See https://pillow.readthedocs.io/en/stable/releasenotes/9.1.0.html#deprecations
5+
# TODO: Remove this file once PIL minimal version is >= 9.1
6+
7+
if tuple(int(part) for part in PIL.__version__.split(".")) >= (9, 1):
8+
BICUBIC = Image.Resampling.BICUBIC
9+
BILINEAR = Image.Resampling.BILINEAR
10+
LINEAR = Image.Resampling.BILINEAR
11+
NEAREST = Image.Resampling.NEAREST
12+
13+
AFFINE = Image.Transform.AFFINE
14+
FLIP_LEFT_RIGHT = Image.Transpose.FLIP_LEFT_RIGHT
15+
FLIP_TOP_BOTTOM = Image.Transpose.FLIP_TOP_BOTTOM
16+
PERSPECTIVE = Image.Transform.PERSPECTIVE
17+
else:
18+
BICUBIC = Image.BICUBIC
19+
BILINEAR = Image.BILINEAR
20+
NEAREST = Image.NEAREST
21+
LINEAR = Image.LINEAR
22+
23+
AFFINE = Image.AFFINE
24+
FLIP_LEFT_RIGHT = Image.FLIP_LEFT_RIGHT
25+
FLIP_TOP_BOTTOM = Image.FLIP_TOP_BOTTOM
26+
PERSPECTIVE = Image.PERSPECTIVE

torchvision/transforms/functional.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ def resize(
392392
:class:`torchvision.transforms.InterpolationMode`.
393393
Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
394394
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
395-
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
395+
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still acceptable.
396396
max_size (int, optional): The maximum allowed for the longer edge of
397397
the resized image: if the longer edge of the image is greater
398398
than ``max_size`` after being resized according to ``size``, then
@@ -572,7 +572,7 @@ def resized_crop(
572572
:class:`torchvision.transforms.InterpolationMode`.
573573
Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
574574
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
575-
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
575+
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still acceptable.
576576
577577
Returns:
578578
PIL Image or Tensor: Cropped image.
@@ -652,7 +652,7 @@ def perspective(
652652
interpolation (InterpolationMode): Desired interpolation enum defined by
653653
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
654654
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
655-
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
655+
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still acceptable.
656656
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
657657
image. If given a number, the value is used for all bands respectively.
658658
@@ -1012,7 +1012,7 @@ def rotate(
10121012
interpolation (InterpolationMode): Desired interpolation enum defined by
10131013
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
10141014
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
1015-
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
1015+
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still acceptable.
10161016
expand (bool, optional): Optional expansion flag.
10171017
If true, expands the output image to make it large enough to hold the entire rotated image.
10181018
If false or omitted, make the output image the same size as the input image.
@@ -1105,7 +1105,7 @@ def affine(
11051105
interpolation (InterpolationMode): Desired interpolation enum defined by
11061106
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
11071107
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
1108-
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
1108+
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still acceptable.
11091109
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
11101110
image. If given a number, the value is used for all bands respectively.
11111111

torchvision/transforms/functional_pil.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import accimage
1111
except ImportError:
1212
accimage = None
13+
from . import _pil_constants
1314

1415

1516
@torch.jit.unused
@@ -54,15 +55,15 @@ def hflip(img: Image.Image) -> Image.Image:
5455
if not _is_pil_image(img):
5556
raise TypeError(f"img should be PIL Image. Got {type(img)}")
5657

57-
return img.transpose(Image.FLIP_LEFT_RIGHT)
58+
return img.transpose(_pil_constants.FLIP_LEFT_RIGHT)
5859

5960

6061
@torch.jit.unused
6162
def vflip(img: Image.Image) -> Image.Image:
6263
if not _is_pil_image(img):
6364
raise TypeError(f"img should be PIL Image. Got {type(img)}")
6465

65-
return img.transpose(Image.FLIP_TOP_BOTTOM)
66+
return img.transpose(_pil_constants.FLIP_TOP_BOTTOM)
6667

6768

6869
@torch.jit.unused
@@ -240,7 +241,7 @@ def crop(
240241
def resize(
241242
img: Image.Image,
242243
size: Union[Sequence[int], int],
243-
interpolation: int = Image.BILINEAR,
244+
interpolation: int = _pil_constants.BILINEAR,
244245
max_size: Optional[int] = None,
245246
) -> Image.Image:
246247

@@ -314,7 +315,7 @@ def _parse_fill(
314315
def affine(
315316
img: Image.Image,
316317
matrix: List[float],
317-
interpolation: int = Image.NEAREST,
318+
interpolation: int = _pil_constants.NEAREST,
318319
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
319320
) -> Image.Image:
320321

@@ -323,14 +324,14 @@ def affine(
323324

324325
output_size = img.size
325326
opts = _parse_fill(fill, img)
326-
return img.transform(output_size, Image.AFFINE, matrix, interpolation, **opts)
327+
return img.transform(output_size, _pil_constants.AFFINE, matrix, interpolation, **opts)
327328

328329

329330
@torch.jit.unused
330331
def rotate(
331332
img: Image.Image,
332333
angle: float,
333-
interpolation: int = Image.NEAREST,
334+
interpolation: int = _pil_constants.NEAREST,
334335
expand: bool = False,
335336
center: Optional[Tuple[int, int]] = None,
336337
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
@@ -347,7 +348,7 @@ def rotate(
347348
def perspective(
348349
img: Image.Image,
349350
perspective_coeffs: float,
350-
interpolation: int = Image.BICUBIC,
351+
interpolation: int = _pil_constants.BICUBIC,
351352
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
352353
) -> Image.Image:
353354

@@ -356,7 +357,7 @@ def perspective(
356357

357358
opts = _parse_fill(fill, img)
358359

359-
return img.transform(img.size, Image.PERSPECTIVE, perspective_coeffs, interpolation, **opts)
360+
return img.transform(img.size, _pil_constants.PERSPECTIVE, perspective_coeffs, interpolation, **opts)
360361

361362

362363
@torch.jit.unused

torchvision/transforms/transforms.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from . import functional as F
1818
from .functional import InterpolationMode, _interpolation_modes_from_int
1919

20-
2120
__all__ = [
2221
"Compose",
2322
"ToTensor",
@@ -298,7 +297,7 @@ class Resize(torch.nn.Module):
298297
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
299298
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and
300299
``InterpolationMode.BICUBIC`` are supported.
301-
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
300+
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still acceptable.
302301
max_size (int, optional): The maximum allowed for the longer edge of
303302
the resized image: if the longer edge of the image is greater
304303
than ``max_size`` after being resized according to ``size``, then
@@ -755,7 +754,7 @@ class RandomPerspective(torch.nn.Module):
755754
interpolation (InterpolationMode): Desired interpolation enum defined by
756755
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
757756
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
758-
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
757+
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still acceptable.
759758
fill (sequence or number): Pixel fill value for the area outside the transformed
760759
image. Default is ``0``. If given a number, the value is used for all bands respectively.
761760
"""
@@ -869,7 +868,7 @@ class RandomResizedCrop(torch.nn.Module):
869868
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
870869
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and
871870
``InterpolationMode.BICUBIC`` are supported.
872-
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
871+
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still acceptable.
873872
874873
"""
875874

@@ -1268,7 +1267,7 @@ class RandomRotation(torch.nn.Module):
12681267
interpolation (InterpolationMode): Desired interpolation enum defined by
12691268
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
12701269
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
1271-
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
1270+
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still acceptable.
12721271
expand (bool, optional): Optional expansion flag.
12731272
If true, expands the output to make it large enough to hold the entire rotated image.
12741273
If false or omitted, make the output image the same size as the input image.
@@ -1389,7 +1388,7 @@ class RandomAffine(torch.nn.Module):
13891388
interpolation (InterpolationMode): Desired interpolation enum defined by
13901389
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
13911390
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
1392-
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
1391+
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still acceptable.
13931392
fill (sequence or number): Pixel fill value for the area outside the transformed
13941393
image. Default is ``0``. If given a number, the value is used for all bands respectively.
13951394
fillcolor (sequence or number, optional):

0 commit comments

Comments
 (0)