-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Fixed issues with dtype in geom functional transforms v2 #7211
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
7213e56
4e843d9
3cca1ab
1abb0a8
9004245
f93c743
c5e17db
b2a3071
d02f3dd
c44dc55
a9be544
35f3412
1fd2887
bf46576
d143d33
3e4eff6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -404,9 +404,12 @@ def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[in | |
|
||
|
||
def _apply_grid_transform( | ||
float_img: torch.Tensor, grid: torch.Tensor, mode: str, fill: datapoints.FillTypeJIT | ||
img: torch.Tensor, grid: torch.Tensor, mode: str, fill: datapoints.FillTypeJIT | ||
) -> torch.Tensor: | ||
|
||
fp = img.dtype == grid.dtype | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you please explain why we're sure that img.dtype is float iff it's the same as the grid dtype? Can't we just use `torch.is_floating_point(img)`? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. yes, we can use `torch.is_floating_point` |
||
float_img = img if fp else img.to(grid.dtype) | ||
|
||
shape = float_img.shape | ||
if shape[0] > 1: | ||
# Apply same grid to a batch of images | ||
|
@@ -433,7 +436,9 @@ def _apply_grid_transform( | |
# img * mask + (1.0 - mask) * fill = img * mask - fill * mask + fill = mask * (img - fill) + fill | ||
float_img = float_img.sub_(fill_img).mul_(mask).add_(fill_img) | ||
|
||
return float_img | ||
img = float_img.round_().to(img.dtype) if not fp else float_img | ||
|
||
return img | ||
|
||
|
||
def _assert_grid_transform_inputs( | ||
|
@@ -511,7 +516,6 @@ def affine_image_tensor( | |
|
||
shape = image.shape | ||
ndim = image.ndim | ||
fp = torch.is_floating_point(image) | ||
|
||
if ndim > 4: | ||
image = image.reshape((-1,) + shape[-3:]) | ||
|
@@ -535,13 +539,10 @@ def affine_image_tensor( | |
|
||
_assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"]) | ||
|
||
dtype = image.dtype if fp else torch.float32 | ||
dtype = image.dtype if torch.is_floating_point(image) else torch.float32 | ||
theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3) | ||
grid = _affine_grid(theta, w=width, h=height, ow=width, oh=height) | ||
output = _apply_grid_transform(image if fp else image.to(dtype), grid, interpolation.value, fill=fill) | ||
|
||
if not fp: | ||
output = output.round_().to(image.dtype) | ||
output = _apply_grid_transform(image, grid, interpolation.value, fill=fill) | ||
|
||
if needs_unsquash: | ||
output = output.reshape(shape) | ||
|
@@ -612,7 +613,7 @@ def _affine_bounding_box_xyxy( | |
# Single point structure is similar to | ||
# [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)] | ||
points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2) | ||
points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1) | ||
points = torch.cat([points, torch.ones(points.shape[0], 1, device=device, dtype=dtype)], dim=-1) | ||
# 2) Now let's transform the points using affine matrix | ||
transformed_points = torch.matmul(points, transposed_affine_matrix) | ||
# 3) Reshape transformed points to [N boxes, 4 points, x/y coords] | ||
|
@@ -797,19 +798,15 @@ def rotate_image_tensor( | |
matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0]) | ||
|
||
if image.numel() > 0: | ||
fp = torch.is_floating_point(image) | ||
image = image.reshape(-1, num_channels, height, width) | ||
|
||
_assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"]) | ||
|
||
ow, oh = _compute_affine_output_size(matrix, width, height) if expand else (width, height) | ||
dtype = image.dtype if fp else torch.float32 | ||
dtype = image.dtype if torch.is_floating_point(image) else torch.float32 | ||
theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3) | ||
grid = _affine_grid(theta, w=width, h=height, ow=ow, oh=oh) | ||
output = _apply_grid_transform(image if fp else image.to(dtype), grid, interpolation.value, fill=fill) | ||
|
||
if not fp: | ||
output = output.round_().to(image.dtype) | ||
output = _apply_grid_transform(image, grid, interpolation.value, fill=fill) | ||
|
||
new_height, new_width = output.shape[-2:] | ||
else: | ||
|
@@ -1237,9 +1234,9 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, | |
|
||
d = 0.5 | ||
base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device) | ||
x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device) | ||
x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device, dtype=dtype) | ||
base_grid[..., 0].copy_(x_grid) | ||
y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device).unsqueeze_(-1) | ||
y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device, dtype=dtype).unsqueeze_(-1) | ||
base_grid[..., 1].copy_(y_grid) | ||
base_grid[..., 2].fill_(1) | ||
|
||
|
@@ -1283,7 +1280,6 @@ def perspective_image_tensor( | |
|
||
shape = image.shape | ||
ndim = image.ndim | ||
fp = torch.is_floating_point(image) | ||
|
||
if ndim > 4: | ||
image = image.reshape((-1,) + shape[-3:]) | ||
|
@@ -1304,12 +1300,9 @@ def perspective_image_tensor( | |
) | ||
|
||
oh, ow = shape[-2:] | ||
dtype = image.dtype if fp else torch.float32 | ||
dtype = image.dtype if torch.is_floating_point(image) else torch.float32 | ||
grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=image.device) | ||
output = _apply_grid_transform(image if fp else image.to(dtype), grid, interpolation.value, fill=fill) | ||
|
||
if not fp: | ||
output = output.round_().to(image.dtype) | ||
output = _apply_grid_transform(image, grid, interpolation.value, fill=fill) | ||
|
||
if needs_unsquash: | ||
output = output.reshape(shape) | ||
|
@@ -1494,8 +1487,9 @@ def elastic_image_tensor( | |
|
||
shape = image.shape | ||
ndim = image.ndim | ||
|
||
device = image.device | ||
fp = torch.is_floating_point(image) | ||
dtype = image.dtype if torch.is_floating_point(image) else torch.float32 | ||
|
||
if ndim > 4: | ||
image = image.reshape((-1,) + shape[-3:]) | ||
|
@@ -1506,12 +1500,15 @@ def elastic_image_tensor( | |
else: | ||
needs_unsquash = False | ||
|
||
image_height, image_width = shape[-2:] | ||
grid = _create_identity_grid((image_height, image_width), device=device).add_(displacement.to(device)) | ||
output = _apply_grid_transform(image if fp else image.to(torch.float32), grid, interpolation.value, fill=fill) | ||
if displacement.dtype != dtype: | ||
displacement = displacement.to(dtype) | ||
|
||
if displacement.device != device: | ||
displacement = displacement.to(device) | ||
vfdev-5 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if not fp: | ||
output = output.round_().to(image.dtype) | ||
image_height, image_width = shape[-2:] | ||
grid = _create_identity_grid((image_height, image_width), device=device, dtype=dtype).add_(displacement) | ||
output = _apply_grid_transform(image, grid, interpolation.value, fill=fill) | ||
|
||
if needs_unsquash: | ||
output = output.reshape(shape) | ||
|
@@ -1531,13 +1528,13 @@ def elastic_image_pil( | |
return to_pil_image(output, mode=image.mode) | ||
|
||
|
||
def _create_identity_grid(size: Tuple[int, int], device: torch.device) -> torch.Tensor: | ||
def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: torch.dtype) -> torch.Tensor: | ||
vfdev-5 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
sy, sx = size | ||
base_grid = torch.empty(1, sy, sx, 2, device=device) | ||
x_grid = torch.linspace((-sx + 1) / sx, (sx - 1) / sx, sx, device=device) | ||
base_grid = torch.empty(1, sy, sx, 2, device=device, dtype=dtype) | ||
x_grid = torch.linspace((-sx + 1) / sx, (sx - 1) / sx, sx, device=device, dtype=dtype) | ||
base_grid[..., 0].copy_(x_grid) | ||
|
||
y_grid = torch.linspace((-sy + 1) / sy, (sy - 1) / sy, sy, device=device).unsqueeze_(-1) | ||
y_grid = torch.linspace((-sy + 1) / sy, (sy - 1) / sy, sy, device=device, dtype=dtype).unsqueeze_(-1) | ||
base_grid[..., 1].copy_(y_grid) | ||
|
||
return base_grid | ||
|
@@ -1552,7 +1549,14 @@ def elastic_bounding_box( | |
return bounding_box | ||
|
||
# TODO: add in docstring about approximation we are doing for grid inversion | ||
displacement = displacement.to(bounding_box.device) | ||
device = bounding_box.device | ||
dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32 | ||
|
||
if displacement.dtype != dtype: | ||
displacement = displacement.to(dtype) | ||
|
||
if displacement.device != device: | ||
displacement = displacement.to(device) | ||
vfdev-5 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
original_shape = bounding_box.shape | ||
bounding_box = ( | ||
|
@@ -1563,7 +1567,7 @@ def elastic_bounding_box( | |
# Or add spatial_size arg and check displacement shape | ||
spatial_size = displacement.shape[-3], displacement.shape[-2] | ||
|
||
id_grid = _create_identity_grid(spatial_size, bounding_box.device) | ||
id_grid = _create_identity_grid(spatial_size, device=device, dtype=dtype) | ||
# We construct an approximation of inverse grid as inv_grid = id_grid - displacement | ||
# This is not an exact inverse of the grid | ||
inv_grid = id_grid.sub_(displacement) | ||
|
Uh oh!
There was an error while loading. Please reload this page.