diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py
index 12fa5288abc..e46dfd74935 100644
--- a/torchvision/prototype/transforms/functional/_color.py
+++ b/torchvision/prototype/transforms/functional/_color.py
@@ -79,9 +79,14 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) ->
     c = image.shape[-3]
     if c not in [1, 3]:
         raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
-    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
-    grayscale_image = _rgb_to_gray(image) if c == 3 else image
-    mean = torch.mean(grayscale_image.to(dtype), dim=(-3, -2, -1), keepdim=True)
+    fp = image.is_floating_point()
+    if c == 3:
+        grayscale_image = _rgb_to_gray(image, cast=False)
+        if not fp:
+            grayscale_image = grayscale_image.floor_()
+    else:
+        grayscale_image = image if fp else image.to(torch.float32)
+    mean = torch.mean(grayscale_image, dim=(-3, -2, -1), keepdim=True)
     return _blend(image, mean, contrast_factor)
 
 
diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py
index 8bcd8176733..a1c5e4723c4 100644
--- a/torchvision/prototype/transforms/functional/_meta.py
+++ b/torchvision/prototype/transforms/functional/_meta.py
@@ -213,10 +213,12 @@ def _gray_to_rgb(grayscale: torch.Tensor) -> torch.Tensor:
     return grayscale.repeat(repeats)
 
 
-def _rgb_to_gray(image: torch.Tensor) -> torch.Tensor:
+def _rgb_to_gray(image: torch.Tensor, cast: bool = True) -> torch.Tensor:
     r, g, b = image.unbind(dim=-3)
-    l_img = (0.2989 * r).add_(g, alpha=0.587).add_(b, alpha=0.114)
-    l_img = l_img.to(image.dtype).unsqueeze(dim=-3)
+    l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114)
+    if cast:
+        l_img = l_img.to(image.dtype)
+    l_img = l_img.unsqueeze(dim=-3)
     return l_img