From 8976cf42e2e7c077d9f12fa2e84a5f9eceef1124 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Wed, 9 Nov 2022 10:41:14 +0000
Subject: [PATCH 1/2] Avoid double casting on adjust_contrast

---
 torchvision/prototype/transforms/functional/_color.py | 4 ++--
 torchvision/prototype/transforms/functional/_meta.py  | 8 +++++---
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py
index 12fa5288abc..33605a98091 100644
--- a/torchvision/prototype/transforms/functional/_color.py
+++ b/torchvision/prototype/transforms/functional/_color.py
@@ -80,8 +80,8 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) ->
     if c not in [1, 3]:
         raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
     dtype = image.dtype if torch.is_floating_point(image) else torch.float32
-    grayscale_image = _rgb_to_gray(image) if c == 3 else image
-    mean = torch.mean(grayscale_image.to(dtype), dim=(-3, -2, -1), keepdim=True)
+    grayscale_image = _rgb_to_gray(image, cast=False) if c == 3 else image.to(dtype)
+    mean = torch.mean(grayscale_image, dim=(-3, -2, -1), keepdim=True)
     return _blend(image, mean, contrast_factor)
 
 
diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py
index 8bcd8176733..a1c5e4723c4 100644
--- a/torchvision/prototype/transforms/functional/_meta.py
+++ b/torchvision/prototype/transforms/functional/_meta.py
@@ -213,10 +213,12 @@ def _gray_to_rgb(grayscale: torch.Tensor) -> torch.Tensor:
     return grayscale.repeat(repeats)
 
 
-def _rgb_to_gray(image: torch.Tensor) -> torch.Tensor:
+def _rgb_to_gray(image: torch.Tensor, cast: bool = True) -> torch.Tensor:
     r, g, b = image.unbind(dim=-3)
-    l_img = (0.2989 * r).add_(g, alpha=0.587).add_(b, alpha=0.114)
-    l_img = l_img.to(image.dtype).unsqueeze(dim=-3)
+    l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114)
+    if cast:
+        l_img = l_img.to(image.dtype)
+    l_img = l_img.unsqueeze(dim=-3)
     return l_img
 
 

From 01fd4346a1251377b8be86362779b345207d5f3a Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Wed, 9 Nov 2022 10:51:31 +0000
Subject: [PATCH 2/2] Handle properly ints.

---
 torchvision/prototype/transforms/functional/_color.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py
index 33605a98091..e46dfd74935 100644
--- a/torchvision/prototype/transforms/functional/_color.py
+++ b/torchvision/prototype/transforms/functional/_color.py
@@ -79,8 +79,13 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) ->
     c = image.shape[-3]
     if c not in [1, 3]:
         raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
-    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
-    grayscale_image = _rgb_to_gray(image, cast=False) if c == 3 else image.to(dtype)
+    fp = image.is_floating_point()
+    if c == 3:
+        grayscale_image = _rgb_to_gray(image, cast=False)
+        if not fp:
+            grayscale_image = grayscale_image.floor_()
+    else:
+        grayscale_image = image if fp else image.to(torch.float32)
     mean = torch.mean(grayscale_image, dim=(-3, -2, -1), keepdim=True)
     return _blend(image, mean, contrast_factor)
 