Skip to content

Make F.rotate/F.affine accept learnable params #5110

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: learnable_params
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,29 @@ def test_rotate_interpolation_type(self):
res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
assert_equal(res1, res2)

@pytest.mark.parametrize("fn", [F.rotate, scripted_rotate])
@pytest.mark.parametrize("center", [None, torch.tensor([0.1, 0.2], requires_grad=True)])
def test_differentiable_rotate(self, fn, center):
    """Gradients must flow back to the angle (and optional center) through rotate.

    Runs against both the eager and the scripted implementation of F.rotate.
    """
    # The parametrized `center` tensor is created once at pytest collection
    # time and shared across every parametrization of this test. Clear any
    # gradient accumulated by a previous run so the `center.grad` assertion
    # below actually verifies that *this* backward pass produced a gradient.
    if center is not None:
        center.grad = None

    alpha = torch.tensor(1.0, requires_grad=True)
    x = torch.zeros(1, 3, 10, 10)
    x[0, :, 2:5, 2:5] = 1.0

    y = fn(x, alpha, interpolation=BILINEAR, center=center)
    assert y.requires_grad
    y.mean().backward()
    assert alpha.grad is not None
    if center is not None:
        assert center.grad is not None

@pytest.mark.parametrize("center", [None, torch.tensor([0.1, 0.2], requires_grad=True)])
def test_differentiable_rotate_nonfloat(self, center):
    """Rotating a non-float tensor with a learnable angle must raise ValueError."""
    angle = torch.tensor(1.0, requires_grad=True)
    img = torch.zeros(1, 3, 10, 10, dtype=torch.long)
    img[0, :, 2:5, 2:5] = 1

    # Differentiable rotation requires a floating-point input image.
    with pytest.raises(ValueError, match=r"input should be float tensor"):
        F.rotate(img, angle, interpolation=BILINEAR, center=center)


class TestAffine:

Expand Down Expand Up @@ -379,6 +402,37 @@ def test_warnings(self, device):
# we convert the PIL images to numpy as assert_equal doesn't work on PIL images.
assert_equal(np.asarray(res1), np.asarray(res2))

@pytest.mark.parametrize("fn", [F.affine, scripted_affine])
@pytest.mark.parametrize("translate", [[0, 0], torch.tensor([1.0, 2.0], requires_grad=True)])
@pytest.mark.parametrize("scale", [1.0, torch.tensor(1.0, requires_grad=True)])
@pytest.mark.parametrize("shear", [[1.0, 1.0], torch.tensor([1.0, 1.0], requires_grad=True)])
def test_differentiable_affine(self, fn, translate, scale, shear):
    """Gradients must flow to every learnable affine parameter through affine.

    Runs against both the eager and the scripted implementation of F.affine;
    each of translate/scale/shear is tested both as a plain value and as a
    learnable tensor.
    """
    # Parametrized tensors are built once at pytest collection time and are
    # shared across all parametrizations of this test. Reset stale gradients
    # so the `.grad is not None` checks below verify this specific backward
    # pass rather than a gradient left over from an earlier run.
    for param in (translate, scale, shear):
        if isinstance(param, torch.Tensor):
            param.grad = None

    alpha = torch.tensor(1.0, requires_grad=True)
    x = torch.zeros(1, 3, 10, 10)
    x[0, :, 2:5, 2:5] = 1.0

    y = fn(x, alpha, translate, scale, shear, interpolation=BILINEAR)
    assert y.requires_grad
    y.mean().backward()
    assert alpha.grad is not None
    if isinstance(translate, torch.Tensor):
        assert translate.grad is not None
    if isinstance(scale, torch.Tensor):
        assert scale.grad is not None
    if isinstance(shear, torch.Tensor):
        assert shear.grad is not None

@pytest.mark.parametrize("translate", [[0, 0], torch.tensor([1.0, 2.0], requires_grad=True)])
@pytest.mark.parametrize("scale", [1.0, torch.tensor(1.0, requires_grad=True)])
@pytest.mark.parametrize("shear", [[1.0, 1.0], torch.tensor([1.0, 1.0], requires_grad=True)])
def test_differentiable_affine_nonfloat(self, translate, scale, shear):
    """An affine transform of a non-float tensor with learnable params must raise ValueError."""
    angle = torch.tensor(1.0, requires_grad=True)
    img = torch.zeros(1, 3, 10, 10, dtype=torch.long)
    img[0, :, 2:5, 2:5] = 1

    # Differentiable affine requires a floating-point input image.
    with pytest.raises(ValueError, match=r"input should be float tensor"):
        F.affine(img, angle, translate, scale, shear, interpolation=BILINEAR)


def _get_data_dims_and_points_for_perspective():
# Ideally we would parametrize independently over data dims and points, but
Expand Down
8 changes: 7 additions & 1 deletion test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2015,7 +2015,13 @@ def _test_transformation(self, angle, translate, scale, shear, pil_image, input_
true_matrix = np.matmul(T, np.matmul(C, np.matmul(RSS, Cinv)))

result_matrix = self._to_3x3_inv(
F._get_inverse_affine_matrix(center=cnt, angle=angle, translate=translate, scale=scale, shear=shear)
F._get_inverse_affine_matrix_tensor(
center=torch.tensor(cnt, dtype=torch.float64), # using double to match true_matrix precision
angle=torch.tensor(angle, dtype=torch.float64),
translate=torch.tensor(translate, dtype=torch.float64),
scale=torch.tensor(scale, dtype=torch.float64),
shear=torch.tensor(shear, dtype=torch.float64),
)
)
assert np.sum(np.abs(true_matrix - result_matrix)) < 1e-10
# 2) Perform inverse mapping:
Expand Down
3 changes: 2 additions & 1 deletion test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,7 @@ def test_resized_crop_save(self, tmpdir):


def _test_random_affine_helper(device, **kwargs):
torch.manual_seed(12)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to add this line and adjust the shear value because of a random one-pixel mismatch between the JIT-scripted and non-scripted results.

tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
transform = T.RandomAffine(**kwargs)
Expand All @@ -482,7 +483,7 @@ def test_random_affine(device, tmpdir):

@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
@pytest.mark.parametrize("shear", [15, 10.0, (5.0, 11.0), [-15, 15], [-10.0, 10.0, -11.0, 11.0]])
def test_random_affine_shear(device, interpolation, shear):
    # Exercise T.RandomAffine (via the shared helper) with every supported
    # shear specification: single number, x-range pair, and x/y-range quadruple.
    _test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, shear=shear)

Expand Down
Loading