Skip to content

Commit efb0736

Browse files
authored
add tensor kernels for normalize and erase (#5462)
* add tensor kernels for normalize and erase * add image tensor assertion
1 parent c6b447b commit efb0736

File tree

2 files changed

+40
-28
lines changed

2 files changed

+40
-28
lines changed

torchvision/transforms/functional.py

Lines changed: 3 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -338,30 +338,9 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
338338
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
339339
_log_api_usage_once(normalize)
340340
if not isinstance(tensor, torch.Tensor):
341-
raise TypeError(f"Input tensor should be a torch tensor. Got {type(tensor)}.")
341+
raise TypeError(f"img should be Tensor Image. Got {type(tensor)}")
342342

343-
if not tensor.is_floating_point():
344-
raise TypeError(f"Input tensor should be a float tensor. Got {tensor.dtype}.")
345-
346-
if tensor.ndim < 3:
347-
raise ValueError(
348-
f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {tensor.size()}"
349-
)
350-
351-
if not inplace:
352-
tensor = tensor.clone()
353-
354-
dtype = tensor.dtype
355-
mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
356-
std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
357-
if (std == 0).any():
358-
raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")
359-
if mean.ndim == 1:
360-
mean = mean.view(-1, 1, 1)
361-
if std.ndim == 1:
362-
std = std.view(-1, 1, 1)
363-
tensor.sub_(mean).div_(std)
364-
return tensor
343+
return F_t.normalize(tensor, mean=mean, std=std, inplace=inplace)
365344

366345

367346
def resize(
@@ -1281,11 +1260,7 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool
12811260
if not isinstance(img, torch.Tensor):
12821261
raise TypeError(f"img should be Tensor Image. Got {type(img)}")
12831262

1284-
if not inplace:
1285-
img = img.clone()
1286-
1287-
img[..., i : i + h, j : j + w] = v
1288-
return img
1263+
return F_t.erase(img, i, j, h, w, v, inplace=inplace)
12891264

12901265

12911266
def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Tensor:

torchvision/transforms/functional_tensor.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -918,3 +918,40 @@ def equalize(img: Tensor) -> Tensor:
918918
return _equalize_single_image(img)
919919

920920
return torch.stack([_equalize_single_image(x) for x in img])
921+
922+
923+
def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
    """Subtract ``mean`` from a float tensor and divide the result by ``std``.

    Args:
        tensor: Floating-point tensor with at least three dimensions; the
            error message describes the expected layout as ``(..., C, H, W)``.
        mean: Values to subtract. A 1-D sequence is reshaped to ``(-1, 1, 1)``
            so that it broadcasts over the last two dimensions.
        std: Values to divide by, reshaped the same way as ``mean``.
        inplace: When ``False`` (default) the arithmetic runs on a clone and
            the input is left untouched; when ``True`` the input itself is
            modified and returned.

    Returns:
        The normalized tensor (the input object itself when ``inplace`` is
        ``True``, otherwise a new tensor).

    Raises:
        TypeError: If ``tensor`` is not a floating-point tensor.
        ValueError: If ``tensor`` has fewer than three dimensions, or if any
            entry of ``std`` equals zero once converted to the tensor's dtype.
    """
    # NOTE(review): _assert_image_tensor is defined elsewhere in this module;
    # presumably it rejects inputs that are not image tensors -- confirm there.
    _assert_image_tensor(tensor)

    if not tensor.is_floating_point():
        raise TypeError(f"Input tensor should be a float tensor. Got {tensor.dtype}.")

    if tensor.ndim < 3:
        raise ValueError(
            f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {tensor.size()}"
        )

    # Work on a private copy unless the caller explicitly asked for mutation.
    out = tensor if inplace else tensor.clone()

    dtype = out.dtype
    device = out.device
    mean_t = torch.as_tensor(mean, dtype=dtype, device=device)
    std_t = torch.as_tensor(std, dtype=dtype, device=device)

    # The zero check runs after the dtype conversion (and before any in-place
    # arithmetic), so a std that becomes zero in the target dtype is caught.
    if (std_t == 0).any():
        raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")

    # 1-D statistics become (N, 1, 1) so they broadcast across the last two dims.
    if mean_t.ndim == 1:
        mean_t = mean_t.view(-1, 1, 1)
    if std_t.ndim == 1:
        std_t = std_t.view(-1, 1, 1)

    out.sub_(mean_t)
    out.div_(std_t)
    return out
948+
949+
950+
def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
    """Overwrite a rectangular window in the last two dimensions of ``img`` with ``v``.

    Args:
        img: Tensor whose window is overwritten.
        i: Start index of the window along the second-to-last dimension.
        j: Start index of the window along the last dimension.
        h: Extent of the window along the second-to-last dimension.
        w: Extent of the window along the last dimension.
        v: Value(s) assigned into the window via slice assignment.
        inplace: When ``False`` (default) the assignment is applied to a clone
            and the input is left untouched; when ``True`` the input itself is
            modified and returned.

    Returns:
        The tensor with the window ``[..., i:i+h, j:j+w]`` set to ``v``.
    """
    # NOTE(review): _assert_image_tensor is defined elsewhere in this module;
    # presumably it rejects inputs that are not image tensors -- confirm there.
    _assert_image_tensor(img)

    # Only touch the caller's tensor when explicitly requested.
    target = img if inplace else img.clone()

    row_end = i + h
    col_end = j + w
    target[..., i:row_end, j:col_end] = v
    return target

0 commit comments

Comments
 (0)