From 2544176396665876ec35901d114859cd6fa38386 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 27 Feb 2023 09:24:04 +0100 Subject: [PATCH] [PoC] simplify simple tensor fallback heuristic --- torchvision/transforms/v2/_transform.py | 61 ++++++++----------------- 1 file changed, 19 insertions(+), 42 deletions(-) diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index f83ed5d6e11..c60dbafa5ef 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -1,14 +1,15 @@ from __future__ import annotations import enum -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type import PIL.Image import torch from torch import nn from torch.utils._pytree import tree_flatten, tree_unflatten from torchvision import datapoints -from torchvision.transforms.v2.utils import check_type, has_any, is_simple_tensor +from torchvision.datapoints._datapoint import Datapoint +from torchvision.transforms.v2.utils import is_simple_tensor from torchvision.utils import _log_api_usage_once @@ -16,7 +17,7 @@ class Transform(nn.Module): # Class attribute defining transformed types. Other types are passed-through without any transformation # We support both Types and callables that are able to do further checks on the type of the input. - _transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (torch.Tensor, PIL.Image.Image) + _transformed_types: Tuple[Type, ...] = (Datapoint, PIL.Image.Image) def __init__(self) -> None: super().__init__() @@ -32,53 +33,29 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: raise NotImplementedError def forward(self, *inputs: Any) -> Any: - flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0]) + sample = inputs if len(inputs) > 1 else inputs[0] + if is_simple_tensor(sample): + sample = datapoints.Image(sample) + simple_tensor_image_fallback = True + else: + simple_tensor_image_fallback = False + + flat_inputs, spec = tree_flatten(sample) self._check_inputs(flat_inputs) - needs_transform_list = self._needs_transform_list(flat_inputs) - params = self._get_params( - [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform] - ) + params = self._get_params(flat_inputs) flat_outputs = [ - self._transform(inpt, params) if needs_transform else inpt - for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) + self._transform(inpt, params) if isinstance(inpt, self._transformed_types) else inpt for inpt in flat_inputs ] - return tree_unflatten(flat_outputs, spec) + outputs = tree_unflatten(flat_outputs, spec) + + if simple_tensor_image_fallback: + outputs = outputs.as_subclass(torch.Tensor) - def _needs_transform_list(self, flat_inputs: List[Any]) -> List[bool]: - # Below is a heuristic on how to deal with simple tensor inputs: - # 1. Simple tensors, i.e. tensors that are not a datapoint, are passed through if there is an explicit image - # (`datapoints.Image` or `PIL.Image.Image`) or video (`datapoints.Video`) in the sample. - # 2. If there is no explicit image or video in the sample, only the first encountered simple tensor is - # transformed as image, while the rest is passed through. The order is defined by the returned `flat_inputs` - # of `tree_flatten`, which recurses depth-first through the input. - # - # This heuristic stems from two requirements: - # 1. We need to keep BC for single input simple tensors and treat them as images. - # 2. We don't want to treat all simple tensors as images, because some datasets like `CelebA` or `Widerface` - # return supplemental numerical data as tensors that cannot be transformed as images. - # - # The heuristic should work well for most people in practice. The only case where it doesn't is if someone - # tries to transform multiple simple tensors at the same time, expecting them all to be treated as images. - # However, this case wasn't supported by transforms v1 either, so there is no BC concern. - - needs_transform_list = [] - transform_simple_tensor = not has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image) - for inpt in flat_inputs: - needs_transform = True - - if not check_type(inpt, self._transformed_types): - needs_transform = False - elif is_simple_tensor(inpt): - if transform_simple_tensor: - transform_simple_tensor = False - else: - needs_transform = False - needs_transform_list.append(needs_transform) - return needs_transform_list + return outputs def extra_repr(self) -> str: extra = []