Skip to content

[PoC] simplify simple tensor fallback heuristic #7340

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 19 additions & 42 deletions torchvision/transforms/v2/_transform.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
from __future__ import annotations

import enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Tuple, Type

import PIL.Image
import torch
from torch import nn
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import datapoints
from torchvision.transforms.v2.utils import check_type, has_any, is_simple_tensor
from torchvision.datapoints._datapoint import Datapoint
from torchvision.transforms.v2.utils import is_simple_tensor
from torchvision.utils import _log_api_usage_once


class Transform(nn.Module):

# Class attribute defining transformed types. Other types are passed-through without any transformation
# We support both Types and callables that are able to do further checks on the type of the input.
_transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (torch.Tensor, PIL.Image.Image)
_transformed_types: Tuple[Type, ...] = (Datapoint, PIL.Image.Image)

def __init__(self) -> None:
super().__init__()
Expand All @@ -32,53 +33,29 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
raise NotImplementedError

def forward(self, *inputs: Any) -> Any:
flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
sample = inputs if len(inputs) > 1 else inputs[0]
if is_simple_tensor(sample):
sample = datapoints.Image(sample)
simple_tensor_image_fallback = True
else:
simple_tensor_image_fallback = False

flat_inputs, spec = tree_flatten(sample)

self._check_inputs(flat_inputs)

needs_transform_list = self._needs_transform_list(flat_inputs)
params = self._get_params(
[inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
)
params = self._get_params(flat_inputs)

flat_outputs = [
self._transform(inpt, params) if needs_transform else inpt
for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
self._transform(inpt, params) if isinstance(inpt, self._transformed_types) else inpt for inpt in flat_inputs
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we no longer care about simple tensors here, we can use the regular isinstance and can potentially eliminate

def check_type(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool:

]

return tree_unflatten(flat_outputs, spec)
outputs = tree_unflatten(flat_outputs, spec)

if simple_tensor_image_fallback:
outputs = outputs.as_subclass(torch.Tensor)

def _needs_transform_list(self, flat_inputs: List[Any]) -> List[bool]:
# Below is a heuristic on how to deal with simple tensor inputs:
# 1. Simple tensors, i.e. tensors that are not a datapoint, are passed through if there is an explicit image
# (`datapoints.Image` or `PIL.Image.Image`) or video (`datapoints.Video`) in the sample.
# 2. If there is no explicit image or video in the sample, only the first encountered simple tensor is
# transformed as image, while the rest is passed through. The order is defined by the returned `flat_inputs`
# of `tree_flatten`, which recurses depth-first through the input.
#
# This heuristic stems from two requirements:
# 1. We need to keep BC for single input simple tensors and treat them as images.
# 2. We don't want to treat all simple tensors as images, because some datasets like `CelebA` or `Widerface`
# return supplemental numerical data as tensors that cannot be transformed as images.
#
# The heuristic should work well for most people in practice. The only case where it doesn't is if someone
# tries to transform multiple simple tensors at the same time, expecting them all to be treated as images.
# However, this case wasn't supported by transforms v1 either, so there is no BC concern.

needs_transform_list = []
transform_simple_tensor = not has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image)
for inpt in flat_inputs:
needs_transform = True

if not check_type(inpt, self._transformed_types):
needs_transform = False
elif is_simple_tensor(inpt):
if transform_simple_tensor:
transform_simple_tensor = False
else:
needs_transform = False
needs_transform_list.append(needs_transform)
return needs_transform_list
return outputs

def extra_repr(self) -> str:
extra = []
Expand Down