4 changes: 2 additions & 2 deletions test/test_prototype_transforms_functional.py
@@ -1185,15 +1185,15 @@ def test_correctness_gaussian_blur_image_tensor(device, spatial_size, dt, ksize,
torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")


# TODO: I guess we need to change the name of this test. Should we have a
# _correctness test as well like the rest?
def test_normalize_output_type():
inpt = torch.rand(1, 3, 32, 32)
output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])
assert type(output) is torch.Tensor
torch.testing.assert_close(inpt - 0.5, output)

inpt = make_image(color_space=datapoints.ColorSpace.RGB)
output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])
assert type(output) is torch.Tensor
torch.testing.assert_close(inpt - 0.5, output)


4 changes: 4 additions & 0 deletions torchvision/prototype/datapoints/_image.py
@@ -289,6 +289,10 @@ def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = N
)
return Image.wrap_like(self, output)

def normalize(self, mean: List[float], std: List[float], inplace: bool = False):
output = self._F.normalize_image_tensor(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace)
return Image.wrap_like(self, output)


ImageType = Union[torch.Tensor, PIL.Image.Image, Image]
ImageTypeJIT = torch.Tensor
4 changes: 4 additions & 0 deletions torchvision/prototype/datapoints/_video.py
@@ -241,6 +241,10 @@ def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = N
output = self._F.gaussian_blur_video(self.as_subclass(torch.Tensor), kernel_size=kernel_size, sigma=sigma)
return Video.wrap_like(self, output)

def normalize(self, mean: List[float], std: List[float], inplace: bool = False):
output = self._F.normalize_video(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace)
return Video.wrap_like(self, output)


VideoType = Union[torch.Tensor, Video]
VideoTypeJIT = torch.Tensor
10 changes: 5 additions & 5 deletions torchvision/prototype/transforms/_color.py
@@ -82,6 +82,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return output


# TODO: Are there tests for this class?
Contributor: I don't think so. We have a pretty good grip on the functional testing, but so far the transforms testing is the wild west. Seems we have forgotten this completely.

class RandomPhotometricDistort(Transform):
_transformed_types = (
datapoints.Image,
@@ -119,15 +120,14 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
def _permute_channels(
self, inpt: Union[datapoints.ImageType, datapoints.VideoType], permutation: torch.Tensor
) -> Union[datapoints.ImageType, datapoints.VideoType]:
if isinstance(inpt, PIL.Image.Image):

orig_inpt = inpt
if isinstance(orig_inpt, PIL.Image.Image):
inpt = F.pil_to_tensor(inpt)

output = inpt[..., permutation, :, :]

if isinstance(inpt, (datapoints.Image, datapoints.Video)):
output = inpt.wrap_like(inpt, output, color_space=datapoints.ColorSpace.OTHER) # type: ignore[arg-type]

elif isinstance(inpt, PIL.Image.Image):
if isinstance(orig_inpt, PIL.Image.Image):
Member (Author): I suspect the original version had a bug: at this point inpt was always a Tensor because of line 122, so the condition on line 130 was never True and this function would always return a Tensor even if a PIL image was passed.

Contributor: Good catch!

output = F.to_image_pil(output)

return output
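Below is a minimal sketch (not part of this diff) of the bug discussed in the thread above. The function names are hypothetical, and it uses the stable torchvision helpers pil_to_tensor / to_pil_image rather than the prototype ones; the point is only that rebinding inpt makes the later PIL check unreachable, while keeping a reference to the original input restores the round-trip back to PIL.

import PIL.Image
import torch
from torchvision.transforms.functional import pil_to_tensor, to_pil_image


def permute_channels_buggy(inpt, permutation):
    if isinstance(inpt, PIL.Image.Image):
        inpt = pil_to_tensor(inpt)  # inpt is rebound to a Tensor here

    output = inpt[..., permutation, :, :]

    if isinstance(inpt, PIL.Image.Image):  # never True: inpt is already a Tensor
        output = to_pil_image(output)
    return output


def permute_channels_fixed(inpt, permutation):
    orig_inpt = inpt  # remember what the caller actually passed in
    if isinstance(orig_inpt, PIL.Image.Image):
        inpt = pil_to_tensor(inpt)

    output = inpt[..., permutation, :, :]

    if isinstance(orig_inpt, PIL.Image.Image):  # check the original reference instead
        output = to_pil_image(output)
    return output


img = PIL.Image.new("RGB", (4, 4))
perm = torch.tensor([2, 0, 1])
print(type(permute_channels_buggy(img, perm)))  # <class 'torch.Tensor'>
print(type(permute_channels_fixed(img, perm)))  # <class 'PIL.Image.Image'>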
20 changes: 8 additions & 12 deletions torchvision/prototype/transforms/functional/_misc.py
@@ -60,18 +60,14 @@ def normalize(
) -> torch.Tensor:
if not torch.jit.is_scripting():
_log_api_usage_once(normalize)

if isinstance(inpt, (datapoints.Image, datapoints.Video)):
inpt = inpt.as_subclass(torch.Tensor)
elif not is_simple_tensor(inpt):
raise TypeError(
f"Input can either be a plain tensor or an `Image` or `Video` datapoint, "
f"but got {type(inpt)} instead."
)

# Image or Video type should not be retained after normalization due to unknown data range
# Thus we return Tensor for input Image
return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
elif isinstance(inpt, (datapoints.Image, datapoints.Video)):
return inpt.normalize(mean=mean, std=std, inplace=inplace)
else:
raise TypeError(
f"Input can either be a plain tensor or an `Image` or `Video` datapoint, " f"but got {type(inpt)} instead."
)


def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
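As a closing usage sketch (not part of this diff): the refactored dispatcher sends plain tensors straight to the kernel, delegates Image and Video datapoints to the normalize methods added in _image.py and _video.py above, and rejects everything else with a TypeError. The import path assumes the prototype package layout shown in the file names; only the plain-tensor path and the error path are exercised here, since they are unambiguous in the code above.

import torch
from torchvision.prototype.transforms import functional as F

inpt = torch.rand(1, 3, 32, 32)
out = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])
# A plain tensor takes the is_simple_tensor branch and goes straight to the kernel.
print(type(out))  # <class 'torch.Tensor'>

# Anything that is neither a plain tensor nor an Image/Video datapoint is rejected.
try:
    F.normalize("not a tensor", mean=[0.5], std=[1.0])
except TypeError as e:
    print(e)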