Skip to content

Commit f2e2f9a

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] Create kernel on-device for transforms.functional.gaussian_blur (#8426)
Reviewed By: vmoens Differential Revision: D58283858 fbshipit-source-id: a4df173fcafe9bce4b35478a7eab5f66f2579180 Co-authored-by: Nicolas Hug <[email protected]> Co-authored-by: Nicolas Hug <[email protected]>
1 parent b2a965d commit f2e2f9a

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

torchvision/transforms/_functional_tensor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -722,10 +722,10 @@ def perspective(
722722
return _apply_grid_transform(img, grid, interpolation, fill=fill)
723723

724724

725-
def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
725+
def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> Tensor:
726726
ksize_half = (kernel_size - 1) * 0.5
727727

728-
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
728+
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size, dtype=dtype, device=device)
729729
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
730730
kernel1d = pdf / pdf.sum()
731731

@@ -735,8 +735,8 @@ def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
735735
def _get_gaussian_kernel2d(
736736
kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
737737
) -> Tensor:
738-
kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
739-
kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
738+
kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0], dtype, device)
739+
kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1], dtype, device)
740740
kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])
741741
return kernel2d
742742

0 commit comments

Comments
 (0)