Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@
VideoClassificationEval,
OpticalFlowEval,
)
from ._type_conversion import DecodeImage, LabelToOneHot
from ._type_conversion import DecodeImage, LabelToOneHot, ToTensor, ImageToPIL, PILToTensor
34 changes: 33 additions & 1 deletion torchvision/prototype/transforms/_type_conversion.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from typing import Any, Dict
from typing import Any, Dict, Optional

import numpy as np
import PIL.Image
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, functional as F

from ._utils import is_simple_tensor


class DecodeImage(Transform):
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
Expand Down Expand Up @@ -33,3 +37,31 @@ def extra_repr(self) -> str:
return ""

return f"num_categories={self.num_categories}"


class ToTensor(Transform):
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, (PIL.Image.Image, np.ndarray)):
return F.to_tensor(input)
else:
return input


class PILToTensor(Transform):
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, PIL.Image.Image):
return F.pil_to_tensor(input)
else:
return input


class ImageToPIL(Transform):
def __init__(self, mode: Optional[str] = None) -> None:
super().__init__()
self.mode = mode

def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if is_simple_tensor(input) or isinstance(input, (features.Image, np.ndarray)):
return F.image_to_pil(input, mode=self.mode)
else:
return input
9 changes: 8 additions & 1 deletion torchvision/prototype/transforms/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,11 @@
ten_crop_image_pil,
)
from ._misc import normalize_image_tensor, gaussian_blur_image_tensor
from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot
from ._type_conversion import (
decode_image_with_pil,
decode_video_with_av,
label_to_one_hot,
to_tensor,
pil_to_tensor,
image_to_pil,
)
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch.nn.functional import one_hot
from torchvision.io.video import read_video
from torchvision.prototype.utils._internal import ReadOnlyTensorBuffer
from torchvision.transforms import functional as F


def decode_image_with_pil(encoded_image: torch.Tensor) -> torch.Tensor:
Expand All @@ -23,3 +24,8 @@ def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, tor

def label_to_one_hot(label: torch.Tensor, *, num_categories: int) -> torch.Tensor:
return one_hot(label, num_classes=num_categories) # type: ignore[no-any-return]


to_tensor = F.to_tensor
pil_to_tensor = F.pil_to_tensor
image_to_pil = F.to_pil_image