Skip to content

Added typing annotations to transforms/functional_pil #4234

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Aug 18, 2021
Merged
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
b2f6615
fix
oke-aditya May 20, 2021
4fb038d
Merge branch 'master' of https://github.com/pytorch/vision
oke-aditya May 20, 2021
deda5d7
Merge branch 'master' of https://github.com/pytorch/vision
oke-aditya May 21, 2021
5490821
Merge branch 'master' of https://github.com/pytorch/vision
oke-aditya May 21, 2021
4cfc220
Merge branch 'master' of https://github.com/pytorch/vision
oke-aditya May 21, 2021
6306746
Merge branch 'master' of https://github.com/pytorch/vision
oke-aditya Jul 24, 2021
e8c93cf
Merge branch 'master' of https://github.com/pytorch/vision
oke-aditya Jul 28, 2021
6871ccc
Merge branch 'master' of https://github.com/pytorch/vision
oke-aditya Jul 28, 2021
80060bf
Merge branch 'master' of https://github.com/pytorch/vision
oke-aditya Jul 29, 2021
3b5f0ca
Merge branch 'master' of https://github.com/pytorch/vision
oke-aditya Jul 30, 2021
0457e6d
Merge branch 'master' of https://github.com/pytorch/vision
oke-aditya Jul 30, 2021
fa9eb08
add functional PIL typings
oke-aditya Jul 30, 2021
1e6dd9e
fix types
oke-aditya Jul 30, 2021
18f28e7
Merge branch 'master' of https://github.com/pytorch/vision into add_t…
oke-aditya Aug 16, 2021
9d294f8
fix types
oke-aditya Aug 16, 2021
6bfd004
fix a small one
oke-aditya Aug 16, 2021
0dce853
small fix
oke-aditya Aug 16, 2021
b6e1638
fix type
oke-aditya Aug 17, 2021
3de8a93
fix interpolation types
oke-aditya Aug 17, 2021
6c94b47
Merge branch 'master' of https://github.com/pytorch/vision into add_t…
oke-aditya Aug 17, 2021
b3cb4c4
Merge branch 'master' into add_typing4
datumbox Aug 18, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 71 additions & 22 deletions torchvision/transforms/functional_pil.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numbers
from typing import Any, List, Sequence
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -34,23 +34,23 @@ def _get_image_num_channels(img: Any) -> int:


@torch.jit.unused
def hflip(img: Image.Image) -> Image.Image:
    """Horizontally flip the given PIL Image.

    Args:
        img (PIL.Image.Image): Image to be flipped.

    Returns:
        PIL.Image.Image: Horizontally flipped image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    # PIL's transpose performs the left-right mirror in C.
    return img.transpose(Image.FLIP_LEFT_RIGHT)


@torch.jit.unused
def vflip(img: Image.Image) -> Image.Image:
    """Vertically flip the given PIL Image.

    Args:
        img (PIL.Image.Image): Image to be flipped.

    Returns:
        PIL.Image.Image: Vertically flipped image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    # PIL's transpose performs the top-bottom mirror in C.
    return img.transpose(Image.FLIP_TOP_BOTTOM)


@torch.jit.unused
def adjust_brightness(img, brightness_factor):
def adjust_brightness(img: Image.Image, brightness_factor: float) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

Expand All @@ -60,7 +60,7 @@ def adjust_brightness(img, brightness_factor):


@torch.jit.unused
def adjust_contrast(img, contrast_factor):
def adjust_contrast(img: Image.Image, contrast_factor: float) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

Expand All @@ -70,7 +70,7 @@ def adjust_contrast(img, contrast_factor):


@torch.jit.unused
def adjust_saturation(img, saturation_factor):
def adjust_saturation(img: Image.Image, saturation_factor: float) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

Expand All @@ -80,7 +80,7 @@ def adjust_saturation(img, saturation_factor):


@torch.jit.unused
def adjust_hue(img, hue_factor):
def adjust_hue(img: Image.Image, hue_factor: float) -> Image.Image:
if not(-0.5 <= hue_factor <= 0.5):
raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))

Expand All @@ -104,7 +104,12 @@ def adjust_hue(img, hue_factor):


@torch.jit.unused
def adjust_gamma(img, gamma, gain=1):
def adjust_gamma(
img: Image.Image,
gamma: float,
gain: float = 1.0,
) -> Image.Image:

if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

Expand All @@ -121,7 +126,13 @@ def adjust_gamma(img, gamma, gain=1):


@torch.jit.unused
def pad(img, padding, fill=0, padding_mode="constant"):
def pad(
img: Image.Image,
padding: Union[int, List[int], Tuple[int, ...]],
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is defined as int on the tensor equivalent method, though it supports both. Do we want them aligned?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we want float in tensor equivalent method to align this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no strong opinion here. I leave it up to you.

padding_mode: str = "constant",
) -> Image.Image:

if not _is_pil_image(img):
raise TypeError("img should be PIL Image. Got {}".format(type(img)))

Expand Down Expand Up @@ -196,15 +207,28 @@ def pad(img, padding, fill=0, padding_mode="constant"):


@torch.jit.unused
def crop(
    img: Image.Image,
    top: int,
    left: int,
    height: int,
    width: int,
) -> Image.Image:
    """Crop the given PIL Image.

    Args:
        img (PIL.Image.Image): Image to be cropped.
        top (int): Vertical component of the top-left corner of the crop box.
        left (int): Horizontal component of the top-left corner of the crop box.
        height (int): Height of the crop box.
        width (int): Width of the crop box.

    Returns:
        PIL.Image.Image: Cropped image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    # PIL.Image.crop takes a (left, upper, right, lower) box, so convert
    # the (top, left, height, width) convention used by torchvision.
    return img.crop((left, top, left + width, top + height))


@torch.jit.unused
def resize(img, size, interpolation=Image.BILINEAR, max_size=None):
def resize(
img: Image.Image,
size: Union[Sequence[int], int],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tensor API uses List instead of Sequence. Do we want them aligned?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think we need to have Sequence in tensor, to align. But probably JIT does not support it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, torch.jit does not support Sequence. Since we also have a Union here, we can't align them anyway. Thus, I think it is ok to go with Sequence.

interpolation: int = Image.BILINEAR,
max_size: Optional[int] = None,
) -> Image.Image:

if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if not (isinstance(size, int) or (isinstance(size, Sequence) and len(size) in (1, 2))):
Expand Down Expand Up @@ -242,7 +266,12 @@ def resize(img, size, interpolation=Image.BILINEAR, max_size=None):


@torch.jit.unused
def _parse_fill(fill, img, name="fillcolor"):
def _parse_fill(
fill: Optional[Union[float, List[float], Tuple[float, ...]]],
img: Image.Image,
name: str = "fillcolor",
) -> Dict[str, Optional[Union[float, List[float], Tuple[float, ...]]]]:

# Process fill color for affine transforms
num_bands = len(img.getbands())
if fill is None:
Expand All @@ -261,7 +290,13 @@ def _parse_fill(fill, img, name="fillcolor"):


@torch.jit.unused
def affine(img, matrix, interpolation=0, fill=None):
def affine(
img: Image.Image,
matrix: List[float],
interpolation: int = Image.NEAREST,
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI strangely the tensor API defines this as float as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually the tensor API defines it to be only List[float] and not all the three, again something which isn't aligned or probably has an issue.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's the union that is not supported. I just highlight that here the type (float vs int) is actually aligned. Which is weird cause we don't seem to keep things aligned. I probably agree that it's the tensor that needs changing though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just verified Sequence is not supported by JIT. And yes Union too is not supported.

https://pytorch.org/docs/stable/jit_language_reference.html#supported-type

) -> Image.Image:

if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

Expand All @@ -271,7 +306,15 @@ def affine(img, matrix, interpolation=0, fill=None):


@torch.jit.unused
def rotate(img, angle, interpolation=0, expand=False, center=None, fill=None):
def rotate(
img: Image.Image,
angle: float,
interpolation: int = Image.NEAREST,
expand: bool = False,
center: Optional[Tuple[int, int]] = None,
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
) -> Image.Image:

if not _is_pil_image(img):
raise TypeError("img should be PIL Image. Got {}".format(type(img)))

Expand All @@ -280,7 +323,13 @@ def rotate(img, angle, interpolation=0, expand=False, center=None, fill=None):


@torch.jit.unused
def perspective(img, perspective_coeffs, interpolation=Image.BICUBIC, fill=None):
def perspective(
img: Image.Image,
perspective_coeffs: float,
interpolation: int = Image.BICUBIC,
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
) -> Image.Image:

if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

Expand All @@ -290,7 +339,7 @@ def perspective(img, perspective_coeffs, interpolation=Image.BICUBIC, fill=None)


@torch.jit.unused
def to_grayscale(img, num_output_channels):
def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

Expand All @@ -308,28 +357,28 @@ def to_grayscale(img, num_output_channels):


@torch.jit.unused
def invert(img: Image.Image) -> Image.Image:
    """Invert the colors of the given PIL Image (delegates to ``ImageOps.invert``).

    Args:
        img (PIL.Image.Image): Image to be inverted.

    Returns:
        PIL.Image.Image: Color-inverted image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    return ImageOps.invert(img)


@torch.jit.unused
def posterize(img: Image.Image, bits: int) -> Image.Image:
    """Posterize the given PIL Image (delegates to ``ImageOps.posterize``).

    Args:
        img (PIL.Image.Image): Image to be posterized.
        bits (int): Number of bits to keep for each channel, passed
            straight through to ``ImageOps.posterize``.

    Returns:
        PIL.Image.Image: Posterized image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    return ImageOps.posterize(img, bits)


@torch.jit.unused
def solarize(img: Image.Image, threshold: int) -> Image.Image:
    """Solarize the given PIL Image (delegates to ``ImageOps.solarize``).

    Args:
        img (PIL.Image.Image): Image to be solarized.
        threshold (int): Threshold value passed straight through to
            ``ImageOps.solarize``.

    Returns:
        PIL.Image.Image: Solarized image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    return ImageOps.solarize(img, threshold)


@torch.jit.unused
def adjust_sharpness(img, sharpness_factor):
def adjust_sharpness(img: Image.Image, sharpness_factor: float) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

Expand All @@ -339,14 +388,14 @@ def adjust_sharpness(img, sharpness_factor):


@torch.jit.unused
def autocontrast(img: Image.Image) -> Image.Image:
    """Maximize the contrast of the given PIL Image (delegates to
    ``ImageOps.autocontrast``).

    Args:
        img (PIL.Image.Image): Image to be adjusted.

    Returns:
        PIL.Image.Image: Contrast-maximized image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    return ImageOps.autocontrast(img)


@torch.jit.unused
def equalize(img: Image.Image) -> Image.Image:
    """Equalize the histogram of the given PIL Image (delegates to
    ``ImageOps.equalize``).

    Args:
        img (PIL.Image.Image): Image to be equalized.

    Returns:
        PIL.Image.Image: Histogram-equalized image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    return ImageOps.equalize(img)