diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 94bed3d7ba6..26fc9913874 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1308,7 +1308,13 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool return F_t.erase(img, i, j, h, w, v, inplace=inplace) -def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Tensor: +def gaussian_blur( + img: Tensor, + kernel_size: List[int], + sigma: Optional[List[float]] = None, + pad_mode: str = "reflect", + pad_value: float = 0.0, +) -> Tensor: """Performs Gaussian blurring on the image by given kernel. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. @@ -1330,6 +1336,8 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[floa .. note:: In torchscript mode sigma as single float is not supported, use a sequence of length 1: ``[sigma, ]``. + pad_mode (str): padding mode to be passed to the padding function. Default: "reflect". + pad_value (float): value to be passed to the padding function. Default: 0.0. Returns: PIL Image or Tensor: Gaussian Blurred version of the image. @@ -1368,7 +1376,7 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[floa t_img = pil_to_tensor(img) - output = F_t.gaussian_blur(t_img, kernel_size, sigma) + output = F_t.gaussian_blur(t_img, kernel_size, sigma, pad_mode, pad_value) if not isinstance(img, torch.Tensor): output = to_pil_image(output, mode=img.mode) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index aa40516d4cb..9e13990c2d8 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -737,7 +737,9 @@ def _get_gaussian_kernel2d( return kernel2d -def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor: +def gaussian_blur( + img: Tensor, kernel_size: List[int], sigma: List[float], pad_mode: str = "reflect", pad_value: float = 0.0 +) -> Tensor: if not (isinstance(img, torch.Tensor)): raise TypeError(f"img should be Tensor. Got {type(img)}") @@ -751,7 +753,7 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te # padding = (left, right, top, bottom) padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2] - img = torch_pad(img, padding, mode="reflect") + img = torch_pad(img, padding, mode=pad_mode, value=pad_value) img = conv2d(img, kernel, groups=img.shape[-3]) img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 985937678a5..5578b732681 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -1754,7 +1754,7 @@ class GaussianBlur(torch.nn.Module): """ - def __init__(self, kernel_size, sigma=(0.1, 2.0)): + def __init__(self, kernel_size, sigma=(0.1, 2.0), pad_mode="reflect", pad_value=0.0): super().__init__() _log_api_usage_once(self) self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers") @@ -1772,7 +1772,18 @@ def __init__(self, kernel_size, sigma=(0.1, 2.0)): else: raise ValueError("sigma should be a single number or a list/tuple with length 2.") + if not isinstance(pad_value, numbers.Number): + raise ValueError("pad_value should be a single float number.") + + if pad_mode not in ["constant", "reflect", "replicate", "circular"]: + raise ValueError( + "Unknown padding mode: {}. Use one of 'constant', 'reflect', 'replicate', 'circular'".format(pad_mode) + ) + elif pad_mode != "constant" and pad_value != 0.0: + raise ValueError("Only constant padding mode supports non-zero pad_value.") self.sigma = sigma + self.pad_mode = pad_mode + self.pad_value = pad_value @staticmethod def get_params(sigma_min: float, sigma_max: float) -> float: @@ -1796,7 +1807,7 @@ def forward(self, img: Tensor) -> Tensor: PIL Image or Tensor: Gaussian blurred image """ sigma = self.get_params(self.sigma[0], self.sigma[1]) - return F.gaussian_blur(img, self.kernel_size, [sigma, sigma]) + return F.gaussian_blur(img, self.kernel_size, [sigma, sigma], self.pad_mode, self.pad_value) def __repr__(self) -> str: s = f"{self.__class__.__name__}(kernel_size={self.kernel_size}, sigma={self.sigma})"