1
1
import warnings
2
- from typing import Optional , Dict , Tuple
2
+ from typing import Optional , Tuple
3
3
4
4
import torch
5
5
from torch import Tensor
@@ -45,6 +45,12 @@ def _max_value(dtype: torch.dtype) -> float:
45
45
return max_value .item ()
46
46
47
47
48
+ def _assert_channels (img : Tensor , permitted : List [int ]) -> None :
49
+ c = _get_image_num_channels (img )
50
+ if c not in permitted :
51
+ raise TypeError ("Input image tensor permitted channel values are {}, but found {}" .format (permitted , c ))
52
+
53
+
48
54
def convert_image_dtype (image : torch .Tensor , dtype : torch .dtype = torch .float ) -> torch .Tensor :
49
55
"""PRIVATE METHOD. Convert a tensor image to the given ``dtype`` and scale the values accordingly
50
56
@@ -210,9 +216,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
210
216
"""
211
217
if img .ndim < 3 :
212
218
raise TypeError ("Input image tensor should have at least 3 dimensions, but found {}" .format (img .ndim ))
213
- c = img .shape [- 3 ]
214
- if c != 3 :
215
- raise TypeError ("Input image tensor should 3 channels, but found {}" .format (c ))
219
+ _assert_channels (img , [3 ])
216
220
217
221
if num_output_channels not in (1 , 3 ):
218
222
raise ValueError ('num_output_channels should be either 1 or 3' )
@@ -230,7 +234,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
230
234
231
235
232
236
def adjust_brightness (img : Tensor , brightness_factor : float ) -> Tensor :
233
- """PRIVATE METHOD. Adjust brightness of an RGB image.
237
+ """PRIVATE METHOD. Adjust brightness of a Grayscale or RGB image.
234
238
235
239
.. warning::
236
240
@@ -252,6 +256,8 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
252
256
if not _is_tensor_a_torch_image (img ):
253
257
raise TypeError ('tensor is not a torch image.' )
254
258
259
+ _assert_channels (img , [1 , 3 ])
260
+
255
261
return _blend (img , torch .zeros_like (img ), brightness_factor )
256
262
257
263
@@ -278,14 +284,16 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
278
284
if not _is_tensor_a_torch_image (img ):
279
285
raise TypeError ('tensor is not a torch image.' )
280
286
287
+ _assert_channels (img , [3 ])
288
+
281
289
dtype = img .dtype if torch .is_floating_point (img ) else torch .float32
282
290
mean = torch .mean (rgb_to_grayscale (img ).to (dtype ), dim = (- 3 , - 2 , - 1 ), keepdim = True )
283
291
284
292
return _blend (img , mean , contrast_factor )
285
293
286
294
287
295
def adjust_hue (img : Tensor , hue_factor : float ) -> Tensor :
288
- """PRIVATE METHOD. Adjust hue of an image.
296
+ """PRIVATE METHOD. Adjust hue of an RGB image.
289
297
290
298
.. warning::
291
299
@@ -320,6 +328,8 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
320
328
if not (isinstance (img , torch .Tensor ) and _is_tensor_a_torch_image (img )):
321
329
raise TypeError ('Input img should be Tensor image' )
322
330
331
+ _assert_channels (img , [3 ])
332
+
323
333
orig_dtype = img .dtype
324
334
if img .dtype == torch .uint8 :
325
335
img = img .to (dtype = torch .float32 ) / 255.0
@@ -359,11 +369,13 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
359
369
if not _is_tensor_a_torch_image (img ):
360
370
raise TypeError ('tensor is not a torch image.' )
361
371
372
+ _assert_channels (img , [3 ])
373
+
362
374
return _blend (img , rgb_to_grayscale (img ), saturation_factor )
363
375
364
376
365
377
def adjust_gamma (img : Tensor , gamma : float , gain : float = 1 ) -> Tensor :
366
- r"""PRIVATE METHOD. Adjust gamma of an RGB image.
378
+ r"""PRIVATE METHOD. Adjust gamma of a Grayscale or RGB image.
367
379
368
380
.. warning::
369
381
@@ -391,6 +403,8 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
391
403
if not isinstance (img , torch .Tensor ):
392
404
raise TypeError ('Input img should be a Tensor.' )
393
405
406
+ _assert_channels (img , [1 , 3 ])
407
+
394
408
if gamma < 0 :
395
409
raise ValueError ('Gamma should be a non-negative real number' )
396
410
0 commit comments