Commit bc4273d

NicolasHug authored and facebook-github-bot committed
[fbsync] Fixed issues with dtype in geom functional transforms v2 (#7211)
Summary:
Co-authored-by: Philip Meier <[email protected]>
Reviewed By: vmoens
Differential Revision: D44416263
fbshipit-source-id: 4bf99470ac106dd8d1c15fa2e217e865508650d4
1 parent 46e07ad commit bc4273d

7 files changed: +101 -56 lines changed

test/prototype_common_utils.py

+3 -3

@@ -304,7 +304,7 @@ def make_image_loaders(
         "RGBA",
     ),
     extra_dims=DEFAULT_EXTRA_DIMS,
-    dtypes=(torch.float32, torch.uint8),
+    dtypes=(torch.float32, torch.float64, torch.uint8),
     constant_alpha=True,
 ):
     for params in combinations_grid(size=sizes, color_space=color_spaces, extra_dims=extra_dims, dtype=dtypes):
@@ -426,7 +426,7 @@ def make_bounding_box_loaders(
     extra_dims=DEFAULT_EXTRA_DIMS,
     formats=tuple(datapoints.BoundingBoxFormat),
     spatial_size="random",
-    dtypes=(torch.float32, torch.int64),
+    dtypes=(torch.float32, torch.float64, torch.int64),
 ):
     for params in combinations_grid(extra_dims=extra_dims, format=formats, dtype=dtypes):
         yield make_bounding_box_loader(**params, spatial_size=spatial_size)
@@ -618,7 +618,7 @@ def make_video_loaders(
     ),
     num_frames=(1, 0, "random"),
     extra_dims=DEFAULT_EXTRA_DIMS,
-    dtypes=(torch.uint8,),
+    dtypes=(torch.uint8, torch.float32, torch.float64),
 ):
     for params in combinations_grid(
         size=sizes, color_space=color_spaces, num_frames=num_frames, extra_dims=extra_dims, dtype=dtypes
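The loaders above expand a parameter grid over every entry in the dtype tuples, so adding torch.float64 makes each sample generator also produce double-precision inputs. A minimal sketch of that kind of grid expansion, using itertools.product as a stand-in for the test suite's combinations_grid helper:

```python
import itertools
import torch

def combinations_grid(**kwargs):
    # Cartesian product over all keyword tuples, yielding one dict per combination
    # (a simplified stand-in for the helper used by the loaders above).
    keys = list(kwargs)
    for values in itertools.product(*kwargs.values()):
        yield dict(zip(keys, values))

# With float64 added, every size is also generated for double tensors.
for params in combinations_grid(
    size=((16, 16), (32, 24)),
    dtype=(torch.float32, torch.float64, torch.uint8),
):
    print(params)
```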

test/prototype_transforms_kernel_infos.py

+21 -2

@@ -109,6 +109,12 @@ def float32_vs_uint8_pixel_difference(atol=1, mae=False):
     }


+def scripted_vs_eager_double_pixel_difference(device, atol=1e-6, rtol=1e-6):
+    return {
+        (("TestKernels", "test_scripted_vs_eager"), torch.float64, device): {"atol": atol, "rtol": rtol, "mae": False},
+    }
+
+
 def pil_reference_wrapper(pil_kernel):
     @functools.wraps(pil_kernel)
     def wrapper(input_tensor, *other_args, **kwargs):
@@ -541,8 +547,10 @@ def reference_affine_bounding_box_helper(bounding_box, *, format, affine_matrix)
     def transform(bbox, affine_matrix_, format_):
         # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
         in_dtype = bbox.dtype
+        if not torch.is_floating_point(bbox):
+            bbox = bbox.float()
         bbox_xyxy = F.convert_format_bounding_box(
-            bbox.float(), old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
+            bbox, old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
         )
         points = np.array(
             [
@@ -560,6 +568,7 @@ def transform(bbox, affine_matrix_, format_):
                 np.max(transformed_points[:, 0]).item(),
                 np.max(transformed_points[:, 1]).item(),
             ],
+            dtype=bbox_xyxy.dtype,
         )
         out_bbox = F.convert_format_bounding_box(
             out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
@@ -844,6 +853,10 @@ def sample_inputs_rotate_video():
     KernelInfo(
         F.rotate_bounding_box,
         sample_inputs_fn=sample_inputs_rotate_bounding_box,
+        closeness_kwargs={
+            **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-6, rtol=1e-6),
+            **scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
+        },
     ),
     KernelInfo(
         F.rotate_mask,
@@ -1275,6 +1288,8 @@ def sample_inputs_perspective_video():
             **pil_reference_pixel_difference(2, mae=True),
             **cuda_vs_cpu_pixel_difference(),
             **float32_vs_uint8_pixel_difference(),
+            **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-5, rtol=1e-5),
+            **scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
         },
     ),
     KernelInfo(
@@ -1294,7 +1309,11 @@ def sample_inputs_perspective_video():
     KernelInfo(
         F.perspective_video,
         sample_inputs_fn=sample_inputs_perspective_video,
-        closeness_kwargs=cuda_vs_cpu_pixel_difference(),
+        closeness_kwargs={
+            **cuda_vs_cpu_pixel_difference(),
+            **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-5, rtol=1e-5),
+            **scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
+        },
     ),
 ]
)
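The new helper keys its tolerances by (test id, dtype, device), and the KernelInfo entries above merge several such dictionaries into one closeness_kwargs mapping. A rough sketch of how those per-dtype, per-device tolerances can be merged and looked up; get_closeness_kwargs here is a simplified stand-in, not the test suite's exact implementation:

```python
import torch

def scripted_vs_eager_double_pixel_difference(device, atol=1e-6, rtol=1e-6):
    # Tolerances that only apply to float64 inputs on the given device.
    return {
        (("TestKernels", "test_scripted_vs_eager"), torch.float64, device): {"atol": atol, "rtol": rtol, "mae": False},
    }

# Merging several helpers, as the KernelInfo entries do.
closeness_kwargs = {
    **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-5, rtol=1e-5),
    **scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
}

def get_closeness_kwargs(kwargs_map, test_id, dtype, device):
    # Simplified lookup: fall back to exact comparison when no entry matches.
    return kwargs_map.get((test_id, dtype, device), {})

print(get_closeness_kwargs(closeness_kwargs, ("TestKernels", "test_scripted_vs_eager"), torch.float64, "cpu"))
```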

test/test_prototype_transforms_consistency.py

+22 -11

@@ -138,17 +138,28 @@ def __init__(
             NotScriptableArgsKwargs(5, padding_mode="symmetric"),
         ],
     ),
-    ConsistencyConfig(
-        prototype_transforms.LinearTransformation,
-        legacy_transforms.LinearTransformation,
-        [
-            ArgsKwargs(LINEAR_TRANSFORMATION_MATRIX, LINEAR_TRANSFORMATION_MEAN),
-        ],
-        # Make sure that the product of the height, width and number of channels matches the number of elements in
-        # `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36.
-        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=["RGB"]),
-        supports_pil=False,
-    ),
+    *[
+        ConsistencyConfig(
+            prototype_transforms.LinearTransformation,
+            legacy_transforms.LinearTransformation,
+            [
+                ArgsKwargs(LINEAR_TRANSFORMATION_MATRIX.to(matrix_dtype), LINEAR_TRANSFORMATION_MEAN.to(matrix_dtype)),
+            ],
+            # Make sure that the product of the height, width and number of channels matches the number of elements in
+            # `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36.
+            make_images_kwargs=dict(
+                DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=["RGB"], dtypes=[image_dtype]
+            ),
+            supports_pil=False,
+        )
+        for matrix_dtype, image_dtype in [
+            (torch.float32, torch.float32),
+            (torch.float64, torch.float64),
+            (torch.float32, torch.uint8),
+            (torch.float64, torch.float32),
+            (torch.float32, torch.float64),
+        ]
+    ],
     ConsistencyConfig(
         prototype_transforms.Grayscale,
         legacy_transforms.Grayscale,

test/test_prototype_transforms_functional.py

+1 -1

@@ -142,7 +142,7 @@ def test_scripted_vs_eager(self, test_id, info, args_kwargs, device):
             actual,
             expected,
             **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
-            msg=parametrized_error_message(*other_args, **kwargs),
+            msg=parametrized_error_message(*([actual, expected] + other_args), **kwargs),
         )

     def _unbatch(self, batch, *, data_dims):

torchvision/prototype/transforms/_misc.py

+8 -1

@@ -64,6 +64,11 @@ def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tenso
                 f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}"
             )

+        if transformation_matrix.dtype != mean_vector.dtype:
+            raise ValueError(
+                f"Input tensors should have the same dtype. Got {transformation_matrix.dtype} and {mean_vector.dtype}"
+            )
+
         self.transformation_matrix = transformation_matrix
         self.mean_vector = mean_vector

@@ -93,7 +98,9 @@ def _transform(
         )

         flat_tensor = inpt.reshape(-1, n) - self.mean_vector
-        transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
+
+        transformation_matrix = self.transformation_matrix.to(flat_tensor.dtype)
+        transformed_tensor = torch.mm(flat_tensor, transformation_matrix)
         return transformed_tensor.reshape(shape)

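The prototype transform now both rejects mismatched matrix/mean dtypes at construction time and casts the matrix to the input's floating dtype before the matmul. A condensed, self-contained sketch of that dtype handling (not the full transform, and the class name here is illustrative only):

```python
import torch

class LinearTransformationSketch:
    def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor):
        # Mismatched parameter dtypes are rejected up front, as in the diff above.
        if transformation_matrix.dtype != mean_vector.dtype:
            raise ValueError(
                f"Input tensors should have the same dtype. Got {transformation_matrix.dtype} and {mean_vector.dtype}"
            )
        self.transformation_matrix = transformation_matrix
        self.mean_vector = mean_vector

    def __call__(self, inpt: torch.Tensor) -> torch.Tensor:
        shape = inpt.shape
        n = shape[-3] * shape[-2] * shape[-1]
        flat = inpt.reshape(-1, n) - self.mean_vector
        # Cast the matrix to the (possibly float64) dtype of the flattened input so torch.mm does not fail.
        matrix = self.transformation_matrix.to(flat.dtype)
        return torch.mm(flat, matrix).reshape(shape)

# A float64 image with a float32 matrix works because of the cast above.
t = LinearTransformationSketch(torch.eye(12), torch.zeros(12))
out = t(torch.rand(3, 2, 2, dtype=torch.float64))
```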

torchvision/prototype/transforms/functional/_geometry.py

+37 -35

@@ -404,9 +404,13 @@ def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[in


 def _apply_grid_transform(
-    float_img: torch.Tensor, grid: torch.Tensor, mode: str, fill: datapoints.FillTypeJIT
+    img: torch.Tensor, grid: torch.Tensor, mode: str, fill: datapoints.FillTypeJIT
 ) -> torch.Tensor:

+    # We are using context knowledge that grid should have float dtype
+    fp = img.dtype == grid.dtype
+    float_img = img if fp else img.to(grid.dtype)
+
     shape = float_img.shape
     if shape[0] > 1:
         # Apply same grid to a batch of images
@@ -433,7 +437,9 @@ def _apply_grid_transform(
         # img * mask + (1.0 - mask) * fill = img * mask - fill * mask + fill = mask * (img - fill) + fill
         float_img = float_img.sub_(fill_img).mul_(mask).add_(fill_img)

-    return float_img
+    img = float_img.round_().to(img.dtype) if not fp else float_img
+
+    return img


 def _assert_grid_transform_inputs(
@@ -511,7 +517,6 @@ def affine_image_tensor(

     shape = image.shape
     ndim = image.ndim
-    fp = torch.is_floating_point(image)

     if ndim > 4:
         image = image.reshape((-1,) + shape[-3:])
@@ -535,13 +540,10 @@ def affine_image_tensor(

     _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])

-    dtype = image.dtype if fp else torch.float32
+    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
     theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
     grid = _affine_grid(theta, w=width, h=height, ow=width, oh=height)
-    output = _apply_grid_transform(image if fp else image.to(dtype), grid, interpolation.value, fill=fill)
-
-    if not fp:
-        output = output.round_().to(image.dtype)
+    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

     if needs_unsquash:
         output = output.reshape(shape)
@@ -612,7 +614,7 @@ def _affine_bounding_box_xyxy(
     # Single point structure is similar to
     # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
     points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
-    points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1)
+    points = torch.cat([points, torch.ones(points.shape[0], 1, device=device, dtype=dtype)], dim=-1)
     # 2) Now let's transform the points using affine matrix
     transformed_points = torch.matmul(points, transposed_affine_matrix)
     # 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
@@ -797,19 +799,15 @@ def rotate_image_tensor(
     matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])

     if image.numel() > 0:
-        fp = torch.is_floating_point(image)
         image = image.reshape(-1, num_channels, height, width)

         _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])

         ow, oh = _compute_affine_output_size(matrix, width, height) if expand else (width, height)
-        dtype = image.dtype if fp else torch.float32
+        dtype = image.dtype if torch.is_floating_point(image) else torch.float32
         theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
         grid = _affine_grid(theta, w=width, h=height, ow=ow, oh=oh)
-        output = _apply_grid_transform(image if fp else image.to(dtype), grid, interpolation.value, fill=fill)
-
-        if not fp:
-            output = output.round_().to(image.dtype)
+        output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

         new_height, new_width = output.shape[-2:]
     else:
@@ -1237,9 +1235,9 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype,

     d = 0.5
     base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
-    x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device)
+    x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device, dtype=dtype)
     base_grid[..., 0].copy_(x_grid)
-    y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device).unsqueeze_(-1)
+    y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device, dtype=dtype).unsqueeze_(-1)
     base_grid[..., 1].copy_(y_grid)
     base_grid[..., 2].fill_(1)

@@ -1283,7 +1281,6 @@ def perspective_image_tensor(

     shape = image.shape
     ndim = image.ndim
-    fp = torch.is_floating_point(image)

     if ndim > 4:
         image = image.reshape((-1,) + shape[-3:])
@@ -1304,12 +1301,9 @@ def perspective_image_tensor(
     )

     oh, ow = shape[-2:]
-    dtype = image.dtype if fp else torch.float32
+    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
     grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=image.device)
-    output = _apply_grid_transform(image if fp else image.to(dtype), grid, interpolation.value, fill=fill)
-
-    if not fp:
-        output = output.round_().to(image.dtype)
+    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

     if needs_unsquash:
         output = output.reshape(shape)
@@ -1494,8 +1488,12 @@ def elastic_image_tensor(

     shape = image.shape
     ndim = image.ndim
+
     device = image.device
-    fp = torch.is_floating_point(image)
+    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
+    # We are aware that if input image dtype is uint8 and displacement is float64 then
+    # displacement will be casted to float32 and all computations will be done with float32
+    # We can fix this later if needed

     if ndim > 4:
         image = image.reshape((-1,) + shape[-3:])
@@ -1506,12 +1504,12 @@ def elastic_image_tensor(
     else:
         needs_unsquash = False

-    image_height, image_width = shape[-2:]
-    grid = _create_identity_grid((image_height, image_width), device=device).add_(displacement.to(device))
-    output = _apply_grid_transform(image if fp else image.to(torch.float32), grid, interpolation.value, fill=fill)
+    if displacement.dtype != dtype or displacement.device != device:
+        displacement = displacement.to(dtype=dtype, device=device)

-    if not fp:
-        output = output.round_().to(image.dtype)
+    image_height, image_width = shape[-2:]
+    grid = _create_identity_grid((image_height, image_width), device=device, dtype=dtype).add_(displacement)
+    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

     if needs_unsquash:
         output = output.reshape(shape)
@@ -1531,13 +1529,13 @@ def elastic_image_pil(
     return to_pil_image(output, mode=image.mode)


-def _create_identity_grid(size: Tuple[int, int], device: torch.device) -> torch.Tensor:
+def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: torch.dtype) -> torch.Tensor:
     sy, sx = size
-    base_grid = torch.empty(1, sy, sx, 2, device=device)
-    x_grid = torch.linspace((-sx + 1) / sx, (sx - 1) / sx, sx, device=device)
+    base_grid = torch.empty(1, sy, sx, 2, device=device, dtype=dtype)
+    x_grid = torch.linspace((-sx + 1) / sx, (sx - 1) / sx, sx, device=device, dtype=dtype)
     base_grid[..., 0].copy_(x_grid)

-    y_grid = torch.linspace((-sy + 1) / sy, (sy - 1) / sy, sy, device=device).unsqueeze_(-1)
+    y_grid = torch.linspace((-sy + 1) / sy, (sy - 1) / sy, sy, device=device, dtype=dtype).unsqueeze_(-1)
     base_grid[..., 1].copy_(y_grid)

     return base_grid
@@ -1552,7 +1550,11 @@ def elastic_bounding_box(
         return bounding_box

     # TODO: add in docstring about approximation we are doing for grid inversion
-    displacement = displacement.to(bounding_box.device)
+    device = bounding_box.device
+    dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
+
+    if displacement.dtype != dtype or displacement.device != device:
+        displacement = displacement.to(dtype=dtype, device=device)

     original_shape = bounding_box.shape
     bounding_box = (
@@ -1563,7 +1565,7 @@ def elastic_bounding_box(
     # Or add spatial_size arg and check displacement shape
     spatial_size = displacement.shape[-3], displacement.shape[-2]

-    id_grid = _create_identity_grid(spatial_size, bounding_box.device)
+    id_grid = _create_identity_grid(spatial_size, device=device, dtype=dtype)
     # We construct an approximation of inverse grid as inv_grid = id_grid - displacement
     # This is not an exact inverse of the grid
     inv_grid = id_grid.sub_(displacement)
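The recurring change across these kernels is that the float upcast and the round-back-to-integer step now live inside _apply_grid_transform instead of being repeated by every caller, with the grid's floating dtype deciding the working precision. A simplified, standalone sketch of that pattern, using torch.nn.functional.grid_sample directly rather than torchvision's internal helpers:

```python
import torch
import torch.nn.functional as F

def apply_grid_transform_sketch(img: torch.Tensor, grid: torch.Tensor, mode: str = "bilinear") -> torch.Tensor:
    # The grid is always floating point (float32 or float64), so its dtype tells us
    # what precision to run the sampling in.
    fp = img.dtype == grid.dtype
    float_img = img if fp else img.to(grid.dtype)

    output = F.grid_sample(float_img, grid, mode=mode, padding_mode="zeros", align_corners=False)

    # Integer inputs (e.g. uint8) are rounded and cast back here, so callers no longer
    # need their own `if not fp: output.round_().to(...)` branch.
    return output if fp else output.round_().to(img.dtype)

# A uint8 image sampled through a float64 grid comes back as uint8.
img = torch.randint(0, 256, (1, 3, 8, 8), dtype=torch.uint8)
theta = torch.eye(2, 3, dtype=torch.float64).unsqueeze(0)
grid = F.affine_grid(theta, size=(1, 3, 8, 8), align_corners=False)
print(apply_grid_transform_sketch(img, grid).dtype)  # torch.uint8
```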

torchvision/transforms/transforms.py

+9 -3

@@ -1078,6 +1078,11 @@ def __init__(self, transformation_matrix, mean_vector):
                 f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}"
             )

+        if transformation_matrix.dtype != mean_vector.dtype:
+            raise ValueError(
+                f"Input tensors should have the same dtype. Got {transformation_matrix.dtype} and {mean_vector.dtype}"
+            )
+
         self.transformation_matrix = transformation_matrix
         self.mean_vector = mean_vector

@@ -1105,9 +1110,10 @@ def forward(self, tensor: Tensor) -> Tensor:
         )

         flat_tensor = tensor.view(-1, n) - self.mean_vector
-        transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
-        tensor = transformed_tensor.view(shape)
-        return tensor
+
+        transformation_matrix = self.transformation_matrix.to(flat_tensor.dtype)
+        transformed_tensor = torch.mm(flat_tensor, transformation_matrix)
+        return transformed_tensor.view(shape)

     def __repr__(self) -> str:
         s = (
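With this change, the stable LinearTransformation raises early on mismatched parameter dtypes and otherwise follows the input tensor's dtype. A small usage sketch of the intended behavior, assuming a torchvision build that includes this commit:

```python
import torch
from torchvision import transforms

matrix = torch.eye(12, dtype=torch.float64)
mean = torch.zeros(12, dtype=torch.float32)

try:
    transforms.LinearTransformation(matrix, mean)  # mismatched parameter dtypes now raise
except ValueError as err:
    print(err)

# Matching parameter dtypes work, and a float64 input no longer breaks torch.mm,
# because the matrix is cast to the input's dtype inside forward().
t = transforms.LinearTransformation(matrix.to(torch.float32), mean)
out = t(torch.rand(3, 2, 2, dtype=torch.float64))
print(out.dtype)  # torch.float64
```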
