Skip to content

Add tests for prototype <-> legacy transforms consistency #6514

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 162 additions & 0 deletions test/test_prototype_transforms_consistency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
import itertools

import pytest
import torch.testing
from test_prototype_transforms_functional import make_images
from torchvision import transforms as legacy_transforms
from torchvision.prototype import features, transforms as prototype_transforms
from torchvision.prototype.transforms.functional import to_image_pil, to_image_tensor

DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[features.ColorSpace.RGB], extra_dims=[(4,)])


class ArgsKwargs:
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs

def __iter__(self):
yield self.args
yield self.kwargs

def __str__(self):
return ", ".join(
itertools.chain(
[repr(arg) for arg in self.args],
[f"{param}={repr(kwarg)}" for param, kwarg in self.kwargs.items()],
)
)


class ConsistencyConfig:
def __init__(
self, prototype_cls, legacy_cls, transform_args_kwargs=None, make_images_kwargs=None, supports_pil=True
):
self.prototype_cls = prototype_cls
self.legacy_cls = legacy_cls
self.transform_args_kwargs = transform_args_kwargs or [((), dict())]
self.make_images_kwargs = make_images_kwargs or DEFAULT_MAKE_IMAGES_KWARGS
self.supports_pil = supports_pil

def parametrization(self):
return [
pytest.param(
self.prototype_cls,
self.legacy_cls,
args_kwargs,
self.make_images_kwargs,
self.supports_pil,
id=f"{self.prototype_cls.__name__}({args_kwargs})",
)
for args_kwargs in self.transform_args_kwargs
]


CONSISTENCY_CONFIGS = [
ConsistencyConfig(
prototype_transforms.Normalize,
legacy_transforms.Normalize,
[
ArgsKwargs(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
],
supports_pil=False,
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.float]),
),
ConsistencyConfig(
prototype_transforms.Resize,
legacy_transforms.Resize,
[
ArgsKwargs(32),
ArgsKwargs((32, 29)),
ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.BICUBIC),
],
),
ConsistencyConfig(
prototype_transforms.CenterCrop,
legacy_transforms.CenterCrop,
[
ArgsKwargs(18),
ArgsKwargs((18, 13)),
],
),
]


@pytest.mark.parametrize(
("prototype_transform_cls", "legacy_transform_cls", "args_kwargs", "make_images_kwargs", "supports_pil"),
itertools.chain.from_iterable(config.parametrization() for config in CONSISTENCY_CONFIGS),
)
def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs, make_images_kwargs, supports_pil):
args, kwargs = args_kwargs

try:
legacy = legacy_transform_cls(*args, **kwargs)
except Exception as exc:
raise pytest.UsageError(
f"Initializing the legacy transform failed with the error above. "
f"Please correct the `ArgsKwargs({args_kwargs})` in the `ConsistencyConfig`."
) from exc

try:
prototype = prototype_transform_cls(*args, **kwargs)
except Exception as exc:
raise AssertionError(
"Initializing the prototype transform failed with the error above. "
"This means there is a consistency bug in the constructor."
) from exc

for image in make_images(**make_images_kwargs):
image_tensor = torch.Tensor(image)
image_pil = to_image_pil(image) if image.ndim == 3 and supports_pil else None

try:
output_legacy_tensor = legacy(image_tensor)
except Exception as exc:
raise pytest.UsageError(
f"Transforming a tensor image with shape {tuple(image.shape)} failed with the error above. "
"This means that you need to specify the parameters passed to `make_images` through the "
"`make_images_kwargs` of the `ConsistencyConfig`."
) from exc

try:
output_prototype_tensor = prototype(image_tensor)
except Exception as exc:
raise AssertionError(
f"Transforming a tensor image with shape {tuple(image.shape)} failed with the error above. "
f"This means there is a consistency bug either in `_get_params` "
f"or in the `is_simple_tensor` path in `_transform`."
) from exc

torch.testing.assert_close(
output_prototype_tensor,
output_legacy_tensor,
atol=0,
rtol=0,
msg=lambda msg: f"Tensor image consistency check failed with: \n\n{msg}",
)

try:
output_prototype_image = prototype(image)
except Exception as exc:
raise AssertionError(
f"Transforming a feature image with shape {tuple(image.shape)} failed with the error above. "
f"This means there is a consistency bug either in `_get_params` "
f"or in the `features.Image` path in `_transform`."
) from exc

torch.testing.assert_close(
torch.Tensor(output_prototype_image),
output_prototype_tensor,
atol=0,
rtol=0,
msg=lambda msg: f"Output for feature and tensor images is not equal: \n\n{msg}",
)

if image_pil is not None:
torch.testing.assert_close(
to_image_tensor(prototype(image_pil)),
to_image_tensor(legacy(image_pil)),
atol=0,
rtol=0,
msg=lambda msg: f"PIL image consistency check failed with: \n\n{msg}",
)
6 changes: 5 additions & 1 deletion torchvision/prototype/transforms/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ def __init__(
) -> None:
super().__init__()

self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
self.size = (
[size]
if isinstance(size, int)
else _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
)
self.interpolation = interpolation
self.max_size = max_size
self.antialias = antialias
Expand Down