[proto] Added functional rotate_segmentation_mask op #5692

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
29 commits merged on Apr 4, 2022
Commits (29)
234f113
Added functional affine_bounding_box op with tests
vfdev-5 Mar 8, 2022
a24fca7
Updated comments and added another test case
vfdev-5 Mar 14, 2022
17ebc0b
Merge branch 'main' into proto-bbox-affine
vfdev-5 Mar 14, 2022
a872483
Update _geometry.py
vfdev-5 Mar 14, 2022
1fc2b44
Merge branch 'main' into proto-bbox-affine
vfdev-5 Mar 14, 2022
7ab7d8a
Added affine_segmentation_mask with tests
vfdev-5 Mar 14, 2022
36ed30a
Fixed device mismatch issue
vfdev-5 Mar 14, 2022
d08d335
Merge branch 'main' into proto-bbox-affine
vfdev-5 Mar 14, 2022
2ca39b0
Merge branch 'proto-bbox-affine' of github.com:vfdev-5/vision into pr…
vfdev-5 Mar 14, 2022
3a277a8
Merge branch 'main' into proto-mask-affine
vfdev-5 Mar 15, 2022
d003051
Added test_correctness_affine_segmentation_mask_on_fixed_input
vfdev-5 Mar 15, 2022
07f0966
Merge branch 'main' of github.com:pytorch/vision into proto-mask-affine
vfdev-5 Mar 15, 2022
7e89062
Updates according to the review
vfdev-5 Mar 16, 2022
acb996a
Merge branch 'main' into proto-mask-affine
vfdev-5 Mar 16, 2022
3010f32
Merge branch 'main' of github.com:pytorch/vision into proto-mask-affine
vfdev-5 Mar 21, 2022
a2be666
Replaced [None, ...] by [None, :]
vfdev-5 Mar 21, 2022
96fb852
Merge branch 'main' of github.com:pytorch/vision into proto-mask-affine
vfdev-5 Mar 23, 2022
9d6ac74
Addressed review comments
vfdev-5 Mar 23, 2022
d17decb
Fixed formatting and more updates according to the review
vfdev-5 Mar 23, 2022
6d43f4a
Merge branch 'main' into proto-mask-affine
vfdev-5 Mar 23, 2022
f4c2243
Fixed bad merge
vfdev-5 Mar 23, 2022
cd317a9
Merge branch 'main' of github.com:pytorch/vision into proto-mask-rotate
vfdev-5 Mar 23, 2022
6037610
WIP
vfdev-5 Mar 25, 2022
fb15186
Merge branch 'main' of github.com:pytorch/vision into proto-mask-rotate
vfdev-5 Mar 25, 2022
8cb3510
Fixed tests
vfdev-5 Mar 28, 2022
85abb24
Merge branch 'main' of github.com:pytorch/vision into proto-mask-rotate
vfdev-5 Mar 28, 2022
2daefa6
Merge branch 'main' of github.com:pytorch/vision into proto-mask-rotate
vfdev-5 Apr 4, 2022
ee9f3d6
Updated warning message
vfdev-5 Apr 4, 2022
e0f7663
Merge branch 'main' into proto-mask-rotate
vfdev-5 Apr 4, 2022
145 changes: 130 additions & 15 deletions test/test_prototype_transforms_functional.py
@@ -266,15 +266,15 @@ def affine_bounding_box():

@register_kernel_info_from_sample_inputs_fn
def affine_segmentation_mask():
for image, angle, translate, scale, shear in itertools.product(
for mask, angle, translate, scale, shear in itertools.product(
make_segmentation_masks(extra_dims=((), (4,))),
[-87, 15, 90], # angle
[5, -5], # translate
[0.77, 1.27], # scale
[0, 12], # shear
):
yield SampleInput(
image,
mask,
angle=angle,
translate=(translate, translate),
scale=scale,
@@ -285,8 +285,12 @@ def affine_segmentation_mask():
@register_kernel_info_from_sample_inputs_fn
def rotate_bounding_box():
for bounding_box, angle, expand, center in itertools.product(
make_bounding_boxes(), [-87, 15, 90], [True, False], [None, [12, 23]] # angle # expand # center
make_bounding_boxes(), [-87, 15, 90], [True, False], [None, [12, 23]]
):
if center is not None and expand:
# Skip warning: The provided center argument is ignored if expand is True
continue

yield SampleInput(
bounding_box,
format=bounding_box.format,
@@ -297,6 +301,26 @@ def rotate_bounding_box():
)


@register_kernel_info_from_sample_inputs_fn
def rotate_segmentation_mask():
for mask, angle, expand, center in itertools.product(
make_segmentation_masks(extra_dims=((), (4,))),
[-87, 15, 90], # angle
[True, False], # expand
[None, [12, 23]], # center
):
if center is not None and expand:
# Skip warning: The provided center argument is ignored if expand is True
continue

yield SampleInput(
mask,
angle=angle,
expand=expand,
center=center,
)


@pytest.mark.parametrize(
"kernel",
[
@@ -411,8 +435,9 @@ def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_):
center=center,
)

if center is None:
center = [s // 2 for s in bboxes_image_size[::-1]]
center_ = center
if center_ is None:
center_ = [s * 0.5 for s in bboxes_image_size[::-1]]

if bboxes.ndim < 2:
bboxes = [bboxes]
@@ -421,7 +446,7 @@ def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_):
for bbox in bboxes:
bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size)
expected_bboxes.append(
_compute_expected_bbox(bbox, angle, (translate, translate), scale, (shear, shear), center)
_compute_expected_bbox(bbox, angle, (translate, translate), scale, (shear, shear), center_)
)
if len(expected_bboxes) > 1:
expected_bboxes = torch.stack(expected_bboxes)
@@ -510,8 +535,10 @@ def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_):
shear=(shear, shear),
center=center,
)
if center is None:
center = [s // 2 for s in mask.shape[-2:][::-1]]

center_ = center
if center_ is None:
center_ = [s * 0.5 for s in mask.shape[-2:][::-1]]

if mask.ndim < 4:
masks = [mask]
@@ -520,7 +547,7 @@ def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_):

expected_masks = []
for mask in masks:
expected_mask = _compute_expected_mask(mask, angle, (translate, translate), scale, (shear, shear), center)
expected_mask = _compute_expected_mask(mask, angle, (translate, translate), scale, (shear, shear), center_)
expected_masks.append(expected_mask)
if len(expected_masks) > 1:
expected_masks = torch.stack(expected_masks)
@@ -550,8 +577,7 @@ def test_correctness_affine_segmentation_mask_on_fixed_input(device):


@pytest.mark.parametrize("angle", range(-90, 90, 56))
@pytest.mark.parametrize("expand", [True, False])
@pytest.mark.parametrize("center", [None, (12, 14)])
@pytest.mark.parametrize("expand, center", [(True, None), (False, None), (False, (12, 14))])
def test_correctness_rotate_bounding_box(angle, expand, center):
def _compute_expected_bbox(bbox, angle_, expand_, center_):
affine_matrix = _compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_)
@@ -620,16 +646,17 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_):
center=center,
)

if center is None:
center = [s // 2 for s in bboxes_image_size[::-1]]
center_ = center
if center_ is None:
center_ = [s * 0.5 for s in bboxes_image_size[::-1]]

if bboxes.ndim < 2:
bboxes = [bboxes]

expected_bboxes = []
for bbox in bboxes:
bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size)
expected_bboxes.append(_compute_expected_bbox(bbox, -angle, expand, center))
expected_bboxes.append(_compute_expected_bbox(bbox, -angle, expand, center_))
if len(expected_bboxes) > 1:
expected_bboxes = torch.stack(expected_bboxes)
else:
@@ -638,7 +665,7 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_):


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("expand", [False]) # expand=True does not match D2, analysis in progress
@pytest.mark.parametrize("expand", [False]) # expand=True does not match D2
def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
# Check transformation against known expected output
image_size = (64, 64)
@@ -689,3 +716,91 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
)

torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)


@pytest.mark.parametrize("angle", range(-90, 90, 37))
@pytest.mark.parametrize("expand, center", [(True, None), (False, None), (False, (12, 14))])
def test_correctness_rotate_segmentation_mask(angle, expand, center):
def _compute_expected_mask(mask, angle_, expand_, center_):
assert mask.ndim == 3 and mask.shape[0] == 1
image_size = mask.shape[-2:]
affine_matrix = _compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_)
inv_affine_matrix = np.linalg.inv(affine_matrix)

if expand_:
# Pillow implementation on how to perform expand:
# https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054-L2069
height, width = image_size
points = np.array(
[
[0.0, 0.0, 1.0],
[0.0, 1.0 * height, 1.0],
[1.0 * width, 1.0 * height, 1.0],
[1.0 * width, 0.0, 1.0],
]
)
new_points = points @ inv_affine_matrix.T
min_vals = np.min(new_points, axis=0)[:2]
max_vals = np.max(new_points, axis=0)[:2]
cmax = np.ceil(np.trunc(max_vals * 1e4) * 1e-4)
cmin = np.floor(np.trunc((min_vals + 1e-8) * 1e4) * 1e-4)
new_width, new_height = (cmax - cmin).astype("int32").tolist()
tr = np.array([-(new_width - width) / 2.0, -(new_height - height) / 2.0, 1.0]) @ inv_affine_matrix.T

inv_affine_matrix[:2, 2] = tr[:2]
image_size = [new_height, new_width]

inv_affine_matrix = inv_affine_matrix[:2, :]
expected_mask = torch.zeros(1, *image_size, dtype=mask.dtype)

for out_y in range(expected_mask.shape[1]):
for out_x in range(expected_mask.shape[2]):
output_pt = np.array([out_x + 0.5, out_y + 0.5, 1.0])
input_pt = np.floor(np.dot(inv_affine_matrix, output_pt)).astype(np.int32)
in_x, in_y = input_pt[:2]
if 0 <= in_x < mask.shape[2] and 0 <= in_y < mask.shape[1]:
expected_mask[0, out_y, out_x] = mask[0, in_y, in_x]
return expected_mask.to(mask.device)

for mask in make_segmentation_masks(extra_dims=((), (4,))):
output_mask = F.rotate_segmentation_mask(
mask,
angle=angle,
expand=expand,
center=center,
)

center_ = center
if center_ is None:
center_ = [s * 0.5 for s in mask.shape[-2:][::-1]]

if mask.ndim < 4:
masks = [mask]
else:
masks = [m for m in mask]

expected_masks = []
for mask in masks:
expected_mask = _compute_expected_mask(mask, -angle, expand, center_)
expected_masks.append(expected_mask)
if len(expected_masks) > 1:
expected_masks = torch.stack(expected_masks)
else:
expected_masks = expected_masks[0]
torch.testing.assert_close(output_mask, expected_masks)


@pytest.mark.parametrize("device", cpu_and_gpu())
def test_correctness_rotate_segmentation_mask_on_fixed_input(device):
# Check transformation against known expected output and CPU/CUDA devices

# Create a fixed input segmentation mask with 2 square masks
# in top-left, bottom-left corners
mask = torch.zeros(1, 32, 32, dtype=torch.long, device=device)
mask[0, 2:10, 2:10] = 1
mask[0, 32 - 9 : 32 - 3, 3:9] = 2

# Rotate 90 degrees
expected_mask = torch.rot90(mask, k=1, dims=(-2, -1))
out_mask = F.rotate_segmentation_mask(mask, 90, expand=False)
torch.testing.assert_close(out_mask, expected_mask)
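
For reference, the expected-output helpers in the tests above build a forward affine matrix with _compute_affine_matrix (defined earlier in this test file), invert it with np.linalg.inv, and then map every output pixel back to an input pixel with a nearest-neighbor lookup. A minimal sketch of the rotation-about-center matrix such a helper produces, assuming the usual T(center) · R(angle) · T(-center) convention (the sign conventions here are illustrative, not copied from the helper itself):

import numpy as np

def rotation_about_center(angle_deg, center):
    # Translate the center to the origin, rotate, then translate back.
    cx, cy = center
    a = np.deg2rad(angle_deg)
    t_to_origin = np.array([[1.0, 0.0, -cx], [0.0, 1.0, -cy], [0.0, 0.0, 1.0]])
    rotation = np.array(
        [[np.cos(a), -np.sin(a), 0.0], [np.sin(a), np.cos(a), 0.0], [0.0, 0.0, 1.0]]
    )
    t_back = np.array([[1.0, 0.0, cx], [0.0, 1.0, cy], [0.0, 0.0, 1.0]])
    return t_back @ rotation @ t_to_origin

# The tests invert this forward matrix and sample each output pixel from the
# input with np.floor, i.e. nearest-neighbor sampling.
inverse_matrix = np.linalg.inv(rotation_about_center(45.0, center=(16.0, 16.0)))
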
1 change: 1 addition & 0 deletions torchvision/prototype/transforms/functional/__init__.py
@@ -56,6 +56,7 @@
rotate_bounding_box,
rotate_image_tensor,
rotate_image_pil,
rotate_segmentation_mask,
pad_image_tensor,
pad_image_pil,
pad_bounding_box,
23 changes: 21 additions & 2 deletions torchvision/prototype/transforms/functional/_geometry.py
@@ -324,7 +324,7 @@ def rotate_image_tensor(
center_f = [0.0, 0.0]
if center is not None:
if expand:
warnings.warn("The provided center argument is ignored if expand is True")
warnings.warn("The provided center argument has no effect on the result if expand is True")
else:
_, height, width = get_dimensions_image_tensor(img)
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
Expand All @@ -345,7 +345,7 @@ def rotate_image_pil(
center: Optional[List[float]] = None,
) -> PIL.Image.Image:
if center is not None and expand:
warnings.warn("The provided center argument is ignored if expand is True")
warnings.warn("The provided center argument has no effect on the result if expand is True")
center = None

return _FP.rotate(
Expand All @@ -361,6 +361,10 @@ def rotate_bounding_box(
expand: bool = False,
center: Optional[List[float]] = None,
) -> torch.Tensor:
if center is not None and expand:
warnings.warn("The provided center argument has no effect on the result if expand is True")
center = None

original_shape = bounding_box.shape
bounding_box = convert_bounding_box_format(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
@@ -373,6 +377,21 @@ ).view(original_shape)
).view(original_shape)


def rotate_segmentation_mask(
img: torch.Tensor,
angle: float,
expand: bool = False,
center: Optional[List[float]] = None,
) -> torch.Tensor:
return rotate_image_tensor(
img,
angle=angle,
expand=expand,
interpolation=InterpolationMode.NEAREST,
center=center,
)


pad_image_tensor = _FT.pad
pad_image_pil = _FP.pad

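
For context, a minimal usage sketch of the new kernel. The namespace matches the imports used in this PR's tests (torchvision.prototype.transforms.functional) and, being a prototype API, may still change before stabilization:

import torch
from torchvision.prototype.transforms import functional as F

# A single-channel integer mask with one labeled square.
mask = torch.zeros(1, 32, 32, dtype=torch.long)
mask[0, 2:10, 2:10] = 1

# Interpolation is fixed to NEAREST inside rotate_segmentation_mask, so label
# values are preserved; expand=True would enlarge the output to fit the result.
rotated = F.rotate_segmentation_mask(mask, angle=90.0, expand=False)
print(rotated.shape)  # torch.Size([1, 32, 32])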