Skip to content

Commit a61f719

Browse files
committed
Remove private doc from invert, create or reuse generic testing methods to avoid duplication of code in the tests.
1 parent 10c3efa commit a61f719

File tree

4 files changed

+30
-62
lines changed

4 files changed

+30
-62
lines changed

test/test_functional_tensor.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -863,19 +863,14 @@ def test_gaussian_blur(self):
863863
)
864864

865865
def test_invert(self):
866-
script_invert = torch.jit.script(F.invert)
867-
868-
img_tensor, pil_img = self._create_data(16, 18, device=self.device)
869-
inverted_img = F.invert(img_tensor)
870-
inverted_pil_img = F.invert(pil_img)
871-
self.compareTensorToPIL(inverted_img, inverted_pil_img)
872-
873-
# scriptable function test
874-
inverted_img_script = script_invert(img_tensor)
875-
self.assertTrue(inverted_img.equal(inverted_img_script))
876-
877-
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
878-
self._test_fn_on_batch(batch_tensors, F.invert)
866+
self._test_adjust_fn(
867+
F.invert,
868+
F_pil.invert,
869+
F_t.invert,
870+
[{}],
871+
tol=1.0,
872+
agg_method="max"
873+
)
879874

880875

881876
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")

test/test_transforms.py

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1749,37 +1749,36 @@ def test_gaussian_blur_asserts(self):
17491749
with self.assertRaisesRegex(ValueError, r"sigma should be a single number or a list/tuple with length 2"):
17501750
transforms.GaussianBlur(3, "sigma_string")
17511751

1752-
@unittest.skipIf(stats is None, 'scipy.stats not available')
1753-
def test_random_invert(self):
1752+
def _test_randomness(self, fn, trans, configs):
17541753
random_state = random.getstate()
17551754
random.seed(42)
17561755
img = transforms.ToPILImage()(torch.rand(3, 10, 10))
1757-
inv_img = F.invert(img)
17581756

1759-
num_samples = 250
1760-
num_inverts = 0
1761-
for _ in range(num_samples):
1762-
out = transforms.RandomInvert()(img)
1763-
if out == inv_img:
1764-
num_inverts += 1
1757+
for p in [0.5, 0.7]:
1758+
for config in configs:
1759+
inv_img = fn(img, **config)
17651760

1766-
p_value = stats.binom_test(num_inverts, num_samples, p=0.5)
1767-
random.setstate(random_state)
1768-
self.assertGreater(p_value, 0.0001)
1761+
num_samples = 250
1762+
counts = 0
1763+
for _ in range(num_samples):
1764+
out = trans(p=p, **config)(img)
1765+
if out == inv_img:
1766+
counts += 1
17691767

1770-
num_samples = 250
1771-
num_inverts = 0
1772-
for _ in range(num_samples):
1773-
out = transforms.RandomInvert(p=0.7)(img)
1774-
if out == inv_img:
1775-
num_inverts += 1
1768+
p_value = stats.binom_test(counts, num_samples, p=p)
1769+
random.setstate(random_state)
1770+
self.assertGreater(p_value, 0.0001)
17761771

1777-
p_value = stats.binom_test(num_inverts, num_samples, p=0.7)
1778-
random.setstate(random_state)
1779-
self.assertGreater(p_value, 0.0001)
1772+
# Checking if it can be printed as string
1773+
trans().__repr__()
17801774

1781-
# Checking if RandomInvert can be printed as string
1782-
transforms.RandomInvert().__repr__()
1775+
@unittest.skipIf(stats is None, 'scipy.stats not available')
1776+
def test_random_invert(self):
1777+
self._test_randomness(
1778+
F.invert,
1779+
transforms.RandomInvert,
1780+
[{}]
1781+
)
17831782

17841783

17851784
if __name__ == '__main__':

torchvision/transforms/functional_pil.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -610,19 +610,6 @@ def to_grayscale(img, num_output_channels):
610610

611611
@torch.jit.unused
612612
def invert(img):
613-
"""PRIVATE METHOD. Invert the colors of an image.
614-
615-
.. warning::
616-
617-
Module ``transforms.functional_pil`` is private and should not be used in user application.
618-
Please, consider instead using methods from `transforms.functional` module.
619-
620-
Args:
621-
img (PIL Image): Image to have its colors inverted.
622-
623-
Returns:
624-
PIL Image: Color inverted image Tensor.
625-
"""
626613
if not _is_pil_image(img):
627614
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
628615
return ImageOps.invert(img)

torchvision/transforms/functional_tensor.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1182,19 +1182,6 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te
11821182

11831183

11841184
def invert(img: Tensor) -> Tensor:
1185-
"""PRIVATE METHOD. Invert the colors of a grayscale or RGB image.
1186-
1187-
.. warning::``
1188-
1189-
Module ``transforms.functional_tensor`` is private and should not be used in user application.
1190-
Please, consider instead using methods from `transforms.functional` module.
1191-
1192-
Args:
1193-
img (Tensor): Image to have its colors inverted in the form [C, H, W].
1194-
1195-
Returns:
1196-
Tensor: Color inverted image Tensor.
1197-
"""
11981185
if not _is_tensor_a_torch_image(img):
11991186
raise TypeError('tensor is not a torch image.')
12001187

0 commit comments

Comments
 (0)