Skip to content

Commit 7f1a05a

Browse files
authored
Check num of channels on adjust_* transformations (#3069)
* Fixing upperbound value on tests and documentation. * Limit the number of channels on adjust_* transoforms.
1 parent 0ebbb0a commit 7f1a05a

File tree

2 files changed

+23
-9
lines changed

2 files changed

+23
-9
lines changed

test/common_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,13 +339,13 @@ def freeze_rng_state():
339339
class TransformsTester(unittest.TestCase):
340340

341341
def _create_data(self, height=3, width=3, channels=3, device="cpu"):
342-
tensor = torch.randint(0, 255, (channels, height, width), dtype=torch.uint8, device=device)
342+
tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device)
343343
pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().cpu().numpy())
344344
return tensor, pil_img
345345

346346
def _create_data_batch(self, height=3, width=3, channels=3, num_samples=4, device="cpu"):
347347
batch_tensor = torch.randint(
348-
0, 255,
348+
0, 256,
349349
(num_samples, channels, height, width),
350350
dtype=torch.uint8,
351351
device=device

torchvision/transforms/functional_tensor.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import warnings
2-
from typing import Optional, Dict, Tuple
2+
from typing import Optional, Tuple
33

44
import torch
55
from torch import Tensor
@@ -45,6 +45,12 @@ def _max_value(dtype: torch.dtype) -> float:
4545
return max_value.item()
4646

4747

48+
def _assert_channels(img: Tensor, permitted: List[int]) -> None:
49+
c = _get_image_num_channels(img)
50+
if c not in permitted:
51+
raise TypeError("Input image tensor permitted channel values are {}, but found {}".format(permitted, c))
52+
53+
4854
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
4955
"""PRIVATE METHOD. Convert a tensor image to the given ``dtype`` and scale the values accordingly
5056
@@ -210,9 +216,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
210216
"""
211217
if img.ndim < 3:
212218
raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
213-
c = img.shape[-3]
214-
if c != 3:
215-
raise TypeError("Input image tensor should 3 channels, but found {}".format(c))
219+
_assert_channels(img, [3])
216220

217221
if num_output_channels not in (1, 3):
218222
raise ValueError('num_output_channels should be either 1 or 3')
@@ -230,7 +234,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
230234

231235

232236
def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
233-
"""PRIVATE METHOD. Adjust brightness of an RGB image.
237+
"""PRIVATE METHOD. Adjust brightness of a Grayscale or RGB image.
234238
235239
.. warning::
236240
@@ -252,6 +256,8 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
252256
if not _is_tensor_a_torch_image(img):
253257
raise TypeError('tensor is not a torch image.')
254258

259+
_assert_channels(img, [1, 3])
260+
255261
return _blend(img, torch.zeros_like(img), brightness_factor)
256262

257263

@@ -278,14 +284,16 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
278284
if not _is_tensor_a_torch_image(img):
279285
raise TypeError('tensor is not a torch image.')
280286

287+
_assert_channels(img, [3])
288+
281289
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
282290
mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True)
283291

284292
return _blend(img, mean, contrast_factor)
285293

286294

287295
def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
288-
"""PRIVATE METHOD. Adjust hue of an image.
296+
"""PRIVATE METHOD. Adjust hue of an RGB image.
289297
290298
.. warning::
291299
@@ -320,6 +328,8 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
320328
if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
321329
raise TypeError('Input img should be Tensor image')
322330

331+
_assert_channels(img, [3])
332+
323333
orig_dtype = img.dtype
324334
if img.dtype == torch.uint8:
325335
img = img.to(dtype=torch.float32) / 255.0
@@ -359,11 +369,13 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
359369
if not _is_tensor_a_torch_image(img):
360370
raise TypeError('tensor is not a torch image.')
361371

372+
_assert_channels(img, [3])
373+
362374
return _blend(img, rgb_to_grayscale(img), saturation_factor)
363375

364376

365377
def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
366-
r"""PRIVATE METHOD. Adjust gamma of an RGB image.
378+
r"""PRIVATE METHOD. Adjust gamma of a Grayscale or RGB image.
367379
368380
.. warning::
369381
@@ -391,6 +403,8 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
391403
if not isinstance(img, torch.Tensor):
392404
raise TypeError('Input img should be a Tensor.')
393405

406+
_assert_channels(img, [1, 3])
407+
394408
if gamma < 0:
395409
raise ValueError('Gamma should be a non-negative real number')
396410

0 commit comments

Comments
 (0)