add prototype features #4721

Merged
merged 16 commits into pytorch:main on Nov 4, 2021

Conversation

@pmeier (Collaborator) commented Oct 22, 2021

For a seamless interplay between the new prototype datasets and the upcoming transforms, the datasets need to forward information about what each item of a sample is in some form. In pmeier/torchvision-datasets-rework#1 it was decided that the best way to do this is to subclass torch.Tensor and wrap the items in custom classes for each type.

An additional benefit is that these custom Features can carry extra metadata. For example, in this PR I've added a color_space attribute for images and a category attribute for labels:

from torchvision.prototype import datasets

for sample in datasets.load("fashionmnist"):
    image, label = sample["image"], sample["label"]
    print(image.shape, image.color_space)
    print(label, label.category)
    break
torch.Size([28, 28]) ColorSpace.GRAYSCALE
tensor(6) Shirt

cc @pmeier @vfdev-5 @datumbox @bjuncek

Comment on lines 37 to 50
@classmethod
def __torch_function__(
    cls,
    func: Callable[..., torch.Tensor],
    types: Tuple[Type[torch.Tensor], ...],
    args: Sequence[Any] = (),
    kwargs: Optional[Mapping[str, Any]] = None,
) -> torch.Tensor:
    with DisableTorchFunction():
        output = func(*args, **(kwargs or dict()))
        if func is not torch.Tensor.clone:
            return output

        return cls(output, like=args[0])
Collaborator Author:

This disables __torch_function__ completely. That means every operation that is performed on a Feature will return a torch.Tensor rather than the original type. The only exception to this is .clone().
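
To illustrate the behavior, a small sketch (the Image class is from this PR; the import path and exact constructor are assumptions):

import torch

from torchvision.prototype.features import Image  # path assumed

image = Image(torch.rand(3, 28, 28))
assert type(torch.flip(image, dims=(1, 2))) is torch.Tensor  # any op returns a plain tensor
assert type(image.clone()) is Image  # .clone() is the one exception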

Member:

While I'm not against this decision, I'm curious about why you did it this way.

Also, what is the recommended way for users to get the data out of the Feature into a Tensor? Would it be torch.as_tensor(image)?

Collaborator Author:

> While I'm not against this decision, I'm curious about why you did it this way.

Given the number of operators that torch supports, it is near impossible for us to make sure that a feature is still its original type after an operation is performed. For example, is Label(torch.tensor(0)) - 1 still a Label? Without disabling __torch_function__, it will be. By disabling it, the user has to manually make the result a Label again, if negative labels are indeed allowed for the use case:

label = Label(torch.tensor(0))
transformed_label = Label(label - 1, like=label)

Clone is the only exception (for now; maybe there are more that also fit this), since we can be fairly sure that cloning a feature should indeed return the original type, including its metadata.

> Also, what is the recommended way for users to get the data out of the Feature into a Tensor? Would it be torch.as_tensor(image)?

All Features are Tensors, so users can access them as they normally would. If for some reason they want a vanilla tensor, feature_tensor.data is enough.

Collaborator Author:

But we probably need to revisit this, because torch.jit.script seems to ignore __torch_function__: https://app.circleci.com/pipelines/github/pytorch/vision/11754/workflows/d1c28de6-1150-496c-b630-9d2059233a5e/jobs/916369/steps. I'll investigate whether this is a limitation I was not aware of or whether my implementation is wrong.

Member:

> feature_tensor.data is enough

Note that .data was deprecated a long time ago, so it's probably not what we want to encourage users to use.

Collaborator:

I'd like to chime in on this discussion. As far as I understand it, the purpose of these tensor subclasses (Feature, Image, Label) is that we can correctly dispatch torchvision transformations on images, bboxes, masks, labels, etc.

  • Lifetime of the Feature, Image, and Label subclasses: IMO they should somehow persist through all transforms.

For example, for a detection task with a very basic "hflip"-like transform, we have the following transformations to apply to the image and the set of bboxes:

transforms = T.Compose([
    T.RandomHorizontalFlip(p=0.5),
    T.ConvertImageDtype(torch.float),
])

When we apply torch.flip from RandomHorizontalFlip, we may still want to know that ConvertImageDtype is only applicable to images and not to bboxes.

Collaborator Author:

True, but inside HorizontalFlip we can simply do this manually:

return BoundingBox(torch.flip(input, dims=(-1,)), like=input)

We can also go a step further and wrap each feature transform for more convenience:

def convert_output_type(feature_transform):
    def wrapper(input, **kwargs):
        input_type = type(input)
        output = feature_transform(input, **kwargs)
        if type(output) is not input_type:
            output = input_type(output, like=input)
        return output

    return wrapper


@convert_output_type
def horizontal_flip(input: BoundingBox):
    return input.flip(-1)
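
A quick sketch of the wrapper in action, ignoring the actual box semantics (the BoundingBox construction is illustrative; the real signature may differ):

box = BoundingBox(torch.tensor([0.0, 0.0, 2.0, 2.0]))
flipped = horizontal_flip(box)
assert type(flipped) is BoundingBox  # re-wrapped by convert_output_type via like=input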

Contributor (@datumbox) left a comment:

@pmeier Looks like a very interesting approach. Excuse the many questions; since I didn't follow the development of this closely, I'm probably missing context. I've raised some points, but if you and Francisco feel they're covered, feel free to dismiss them. I also assume that you have made sure that this solution is JIT compatible.

T = TypeVar("T", bound="Feature")


class Feature(torch.Tensor):
Contributor:

Can you talk about the reasons this is named Feature? The term is quite overloaded in ML and might cause confusion. Having said that, off the top of my head I don't have a good replacement to offer.

Collaborator Author:

That is the name that tensorflow_datasets uses. I agree it is quite overloaded, but unless we come up with a better one, I wouldn't change it.

Contributor:

Is there another reason other than keeping it consistent with TF? I'm trying to understand where "Feature" comes from. I get it for Image, but since the same class is the parent of Label, it's less intuitive why. Both of them are data points and together they form a record. Again, I don't have a better name to suggest, but it's something to consider.

Collaborator Author (@pmeier, Oct 25, 2021):

> Is there another reason other than keeping it consistent with TF?

No, none that I can think of. Happy to change it if you have a better proposal.

class ColorSpace(enum.Enum):
    OTHER = 0
    GRAYSCALE = 1
    RGB = 3
Contributor:

How about transparency, which we already support? Also, does it make sense to align the two enums:

class ImageReadMode(Enum):
    """
    Support for various modes while reading images.
    Use ``ImageReadMode.UNCHANGED`` for loading the image as-is,
    ``ImageReadMode.GRAY`` for converting to grayscale,
    ``ImageReadMode.GRAY_ALPHA`` for grayscale with transparency,
    ``ImageReadMode.RGB`` for RGB and ``ImageReadMode.RGB_ALPHA`` for
    RGB with transparency.
    """

    UNCHANGED = 0
    GRAY = 1
    GRAY_ALPHA = 2
    RGB = 3
    RGB_ALPHA = 4

Collaborator Author:

Yes, we should align the two. Maybe we can go with ImageMode as a unified name?
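For concreteness, a sketch of what the aligned enum could look like (the ImageMode name is just the proposal above, not final; the values mirror ImageReadMode):

import enum


class ImageMode(enum.Enum):
    UNCHANGED = 0
    GRAY = 1
    GRAY_ALPHA = 2
    RGB = 3
    RGB_ALPHA = 4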

Contributor:

Sounds great. FYI the ImageReadMode has an equivalent C++ definition which also needs to be synced:

using ImageReadMode = int64_t;
const ImageReadMode IMAGE_READ_MODE_UNCHANGED = 0;
const ImageReadMode IMAGE_READ_MODE_GRAY = 1;
const ImageReadMode IMAGE_READ_MODE_GRAY_ALPHA = 2;
const ImageReadMode IMAGE_READ_MODE_RGB = 3;
const ImageReadMode IMAGE_READ_MODE_RGB_ALPHA = 4;

Member (@fmassa) left a comment:

Thanks for the PR!

I'm approving to unblock. I assume you'll do the rest of the dataset changes to use Feature in a follow-up PR?

feature_types = pytest.mark.parametrize(
    "feature_type", FEATURE_TYPES, ids=lambda feature_type: feature_type.__name__
)

Member:

Can you add a test for serialization of the feature? Something like torch.save / torch.load, to make sure that the extra metadata is not lost along the way?
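
A minimal sketch of such a round-trip test (the category kwarg on the Label constructor is an assumption; the attribute itself is from this PR):

import io

import torch

def test_feature_serialization():
    label = Label(torch.tensor(6), category="Shirt")  # constructor kwarg assumed
    buffer = io.BytesIO()
    torch.save(label, buffer)
    buffer.seek(0)
    reloaded = torch.load(buffer)
    assert type(reloaded) is Label  # subclass survives the round trip
    assert reloaded.category == "Shirt"  # and so does the metadata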

if decoder is raw:
    image = torch.from_numpy(image_array)
    image = Image(image_array)
Member:

I know I've brought this up in the past regarding performance, but one thing we should keep in mind is that having the same field represent two different entities can be quite confusing. It might be worth considering renaming one or the other, but it's up for discussion.

Collaborator Author:

I would argue that changing the key-value pair (dict(image=<torch.Tensor>) vs. dict(image_buffer=<file handle>)) depending on the value of decoder is more confusing than just changing the value (dict(image=<torch.Tensor or file handle>)).

Member:

Let's leave it up for discussion and gather feedback from users then.

Collaborator Author:

There is also a larger discussion to be had about how, or whether, we standardize the sample keys. For example, should we always have an "input" key? That would make things easier, because users could always use it regardless of the dataset. On the other hand, it gets weird for datasets that do not have a clear input, for example when returning a pair of images.

Comment on lines 168 to 170
print(actual_jit)
print(expected)
assert_close(actual_jit, expected)
Collaborator Author (@pmeier, Nov 3, 2021):

This test fails, but I was unable to get it to pass. The eager variant works as expected, but the jitted one gives incorrect output:

tensor([2.6000, 1.6000, 2.6000, 1.6000])
BoundingBox([ 6,  1, 10,  3])
AssertionError: The values for attribute 'dtype' do not match: torch.float32 != torch.int64.

I'm not sure if the test setup is somehow wrong or if JIT simply doesn't work with tensor subclasses that carry additional metadata. This needs to be reviewed by someone with more JIT experience than I have. cc @fmassa @datumbox

Member:

Pinging @Chillee here, but my gut feeling is that attributes of the Tensor subclass might be embedded in the graph as constants.

Member:

You can check the computation graph of the traced model (via print(traced_function.code)) to see if the image_size was embedded as a constant or not.

Another thing would be to make the image_size of the sample_inputs the same as for the input, which should make the tests pass (but it would still be a bug).

Collaborator Author:

You are probably right:

def compose(input: Tensor,
    size: Tensor) -> Tensor:
  new_height, new_width, = torch.unbind(size)
  height_scale = torch.div(new_height, CONSTANTS.c0)
  width_scale = torch.div(new_width, CONSTANTS.c0)
  data = torch.clone(input)
  data0 = torch.to(data, torch.device("cpu"), 6)
  old_x1, old_y1, old_x2, old_y2, = torch.unbind(torch.detach(data0), -1)
  a = torch.mul(old_x1, width_scale)
  b = torch.mul(old_y1, height_scale)
  c = torch.mul(old_x2, width_scale)
  d = torch.mul(old_y2, height_scale)
  data1 = torch.stack([a, b, c, d], -1)
  data2 = torch.to(data1, torch.device("cpu"), 6)
  x, y, x2, y2, = torch.unbind(torch.detach(data2), -1)
  c0 = torch.sub(x2, x)
  d0 = torch.sub(y2, y)
  data3 = torch.stack([x, y, c0, d0], -1)
  data4 = torch.to(data3, torch.device("cpu"), 6)
  x0, b0, w, d1, = torch.unbind(torch.detach(data4), -1)
  a0 = torch.rsub(torch.add(x0, w), 5)
  data5 = torch.stack([a0, b0, w, d1], -1)
  data6 = torch.to(data5, torch.device("cpu"), 6)
  x1, y1, w0, h, = torch.unbind(torch.detach(data6), -1)
  c1 = torch.add(x1, w0)
  d2 = torch.add(y1, h)
  data7 = torch.stack([x1, y1, c1, d2], -1)
  data8 = torch.to(data7, torch.device("cpu"), 6)
  return torch.detach(data8)

The important lines are

new_height, new_width, = torch.unbind(size)
height_scale = torch.div(new_height, CONSTANTS.c0)
width_scale = torch.div(new_width, CONSTANTS.c0)

which come from

old_height, old_width = input.image_size
new_height, new_width = size

height_scale = new_height / old_height
width_scale = new_width / old_width

So old_height and old_width, which come directly from the BoundingBox.image_size metadata, are stored as constants.
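
For anyone hitting this later, a minimal repro of the tracing pitfall, independent of the Feature classes (any value converted to a Python number at trace time gets baked into the graph):

import torch

def normalize(t: torch.Tensor) -> torch.Tensor:
    return t / float(t.sum())  # float() exits the graph, so the value is baked in

traced = torch.jit.trace(normalize, torch.ones(4))  # baked-in constant: 4.0
print(traced.code)  # shows the constant instead of a computed sum
print(traced(torch.ones(8)))  # silently wrong: still divides by 4.0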

Member:

Here is what I propose: let's forget TorchScript for now and continue with the prototype.

@pmeier pmeier requested a review from fmassa November 3, 2021 07:46
@pmeier pmeier merged commit fe78a8a into pytorch:main Nov 4, 2021
@pmeier pmeier deleted the prototype/features branch November 4, 2021 09:59
facebook-github-bot pushed a commit that referenced this pull request Nov 8, 2021
Summary:
* add prototype features

* add some JIT tests

* refactor input data handling

* refactor tests

* cleanup tests

* add BoundingBox feature

* mypy

* xfail torchscript tests for now

* cleanup

* fix imports

Reviewed By: kazhang

Differential Revision: D32216676

fbshipit-source-id: a2579470737c6c533cea45aa5a9ad45d5e23aec7
cyyever pushed a commit to cyyever/vision that referenced this pull request Nov 16, 2021