
Commit 67d1f1f

Yosua Michael Maranatha authored and pmeier committed
[fbsync] [prototype] Port elastic and minor cleanups (#6942)
Summary:
* Port elastic and minor cleanups
* Update torchvision/prototype/transforms/functional/_geometry.py

Reviewed By: NicolasHug

Differential Revision: D41265205

fbshipit-source-id: c62a7764bc7bf3efecb9bbb24ce1f192226c4438

Co-authored-by: Philip Meier <[email protected]>
1 parent 3306159 commit 67d1f1f

File tree

2 files changed: +13 -11 lines changed


torchvision/prototype/transforms/functional/_color.py

Lines changed: 1 addition & 1 deletion
@@ -388,7 +388,7 @@ def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
     maximum = float_image.amax(dim=(-2, -1), keepdim=True)
 
     eq_idxs = maximum == minimum
-    inv_scale = maximum.sub_(minimum).div_(bound)
+    inv_scale = maximum.sub_(minimum).mul_(1.0 / bound)
     minimum[eq_idxs] = 0.0
     inv_scale[eq_idxs] = 1.0
 
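Note on the change: the in-place scalar division is swapped for multiplication by the reciprocal, which is typically cheaper and matches the old result up to floating-point rounding. A minimal sketch of the equivalence (the tensor shape and the uint8 `bound` of 255.0 are assumptions based on the surrounding kernel, not part of this diff):

    import torch

    bound = 255.0  # assumed dynamic range for uint8 images
    float_image = torch.rand(1, 3, 4, 4) * bound

    minimum = float_image.amin(dim=(-2, -1), keepdim=True)
    maximum = float_image.amax(dim=(-2, -1), keepdim=True)

    old = maximum.clone().sub_(minimum).div_(bound)        # previous kernel
    new = maximum.clone().sub_(minimum).mul_(1.0 / bound)  # this commit

    torch.testing.assert_close(old, new)  # equal up to rounding error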

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 12 additions & 10 deletions
@@ -390,7 +390,7 @@ def _affine_bounding_box_xyxy(
         device=device,
     )
     new_points = torch.matmul(points, transposed_affine_matrix)
-    tr, _ = torch.min(new_points, dim=0, keepdim=True)
+    tr = torch.amin(new_points, dim=0, keepdim=True)
     # Translate bounding boxes
     out_bboxes.sub_(tr.repeat((1, 2)))
     # Estimate meta-data for image with inverted=True and with center=[0,0]
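Note: `torch.min(t, dim=...)` returns a `(values, indices)` pair, so the old line computed an argmin only to discard it; `torch.amin` returns just the values. A quick illustration with a stand-in tensor:

    import torch

    new_points = torch.rand(4, 2)  # stand-in for the transformed corner points

    values, _ = torch.min(new_points, dim=0, keepdim=True)  # also computes indices
    tr = torch.amin(new_points, dim=0, keepdim=True)        # values only

    assert torch.equal(values, tr)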
@@ -701,7 +701,7 @@ def pad_image_tensor(
     # internally.
     torch_padding = _parse_pad_padding(padding)
 
-    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
+    if padding_mode not in ("constant", "edge", "reflect", "symmetric"):
         raise ValueError(
             f"`padding_mode` should be either `'constant'`, `'edge'`, `'reflect'` or `'symmetric'`, "
             f"but got `'{padding_mode}'`."
@@ -917,17 +917,17 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype,
     # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
     # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
     #
-
+    # TODO: should we define them transposed?
     theta1 = torch.tensor(
         [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
     )
     theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device)
 
     d = 0.5
     base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
-    x_grid = torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow, device=device)
+    x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device)
     base_grid[..., 0].copy_(x_grid)
-    y_grid = torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh, device=device).unsqueeze_(-1)
+    y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device).unsqueeze_(-1)
     base_grid[..., 1].copy_(y_grid)
     base_grid[..., 2].fill_(1)
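Note: `ow` and `oh` are `int`, and adding the float `d - 1.0` already promotes the endpoint to `float`, so the `* 1.0` was redundant and the resulting grids are identical. A quick check with an example size:

    import torch

    ow, d = 7, 0.5  # example width and the half-pixel offset used above

    old = torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow)
    new = torch.linspace(d, ow + d - 1.0, steps=ow)

    assert torch.equal(old, new)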

@@ -1059,6 +1059,7 @@ def perspective_bounding_box(
         (-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom,
     ]
 
+    # TODO: should we define them transposed?
     theta1 = torch.tensor(
         [[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]],
         dtype=dtype,
@@ -1165,14 +1166,17 @@ def elastic_image_tensor(
         return image
 
     shape = image.shape
+    device = image.device
 
     if image.ndim > 4:
         image = image.reshape((-1,) + shape[-3:])
         needs_unsquash = True
     else:
         needs_unsquash = False
 
-    output = _FT.elastic_transform(image, displacement, interpolation=interpolation.value, fill=fill)
+    image_height, image_width = shape[-2:]
+    grid = _create_identity_grid((image_height, image_width), device=device).add_(displacement.to(device))
+    output = _FT._apply_grid_transform(image, grid, interpolation.value, fill)
 
     if needs_unsquash:
         output = output.reshape(shape)
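Note on the port: instead of delegating to the stable kernel `_FT.elastic_transform`, the prototype kernel now performs the two steps itself: build an identity sampling grid, offset it by the displacement field, and resample the image. A self-contained sketch of the same idea using plain `grid_sample` (the normalized-coordinate grid below is a reconstruction, not the actual body of `_create_identity_grid`, and a zero displacement only approximates the identity under `align_corners=False`):

    import torch
    import torch.nn.functional as F

    def elastic_sketch(image: torch.Tensor, displacement: torch.Tensor) -> torch.Tensor:
        # image: (N, C, H, W); displacement: (1, H, W, 2) in normalized [-1, 1] units
        n, _, h, w = image.shape
        ys, xs = torch.meshgrid(
            torch.linspace(-1.0, 1.0, h, device=image.device),
            torch.linspace(-1.0, 1.0, w, device=image.device),
            indexing="ij",
        )
        grid = torch.stack([xs, ys], dim=-1).unsqueeze(0)  # identity grid, (1, H, W, 2)
        grid = grid + displacement.to(image.device)        # shift the sampling locations
        return F.grid_sample(image, grid.expand(n, -1, -1, -1), align_corners=False)

    out = elastic_sketch(torch.rand(2, 3, 8, 8), torch.zeros(1, 8, 8, 2))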
@@ -1505,8 +1509,7 @@ def five_crop_image_tensor(
     image_height, image_width = image.shape[-2:]
 
     if crop_width > image_width or crop_height > image_height:
-        msg = "Requested crop size {} is bigger than input size {}"
-        raise ValueError(msg.format(size, (image_height, image_width)))
+        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")
 
     tl = crop_image_tensor(image, 0, 0, crop_height, crop_width)
     tr = crop_image_tensor(image, 0, image_width - crop_width, crop_height, crop_width)
@@ -1525,8 +1528,7 @@ def five_crop_image_pil(
     image_height, image_width = get_spatial_size_image_pil(image)
 
     if crop_width > image_width or crop_height > image_height:
-        msg = "Requested crop size {} is bigger than input size {}"
-        raise ValueError(msg.format(size, (image_height, image_width)))
+        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")
 
     tl = crop_image_pil(image, 0, 0, crop_height, crop_width)
    tr = crop_image_pil(image, 0, image_width - crop_width, crop_height, crop_width)
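Note: both `five_crop` variants fold the two-step `str.format` message into a single f-string. The parenthesized `(image_height, image_width)` inside the replacement field is an explicit tuple expression, so the message formats exactly as the old `msg.format(...)` output did. For example:

    size = (224, 224)
    image_height, image_width = 200, 180
    print(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")
    # Requested crop size (224, 224) is bigger than input size (200, 180)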
