Skip to content

Commit b621e38

Browse files
authored
Remove p-value checks in test_transforms.py (#4756)
* Change test_random_apply * Change test_random_choice * Change test_randomness * took care of RandomVert/HorizFlip * take care of RandomGrayScale * minor cleanup * avoid 0 degree rotation just in case
1 parent bbfda42 commit b621e38

File tree

2 files changed

+45
-230
lines changed

2 files changed

+45
-230
lines changed

test/test_transforms.py

Lines changed: 44 additions & 229 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import math
22
import os
33
import random
4+
from functools import partial
45

56
import numpy as np
67
import pytest
@@ -541,38 +542,35 @@ def test_pad_with_mode_F_images(self):
541542
assert_equal(padded_img.size, [edge_size + 2 * pad for edge_size in img.size])
542543

543544

544-
@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
545545
@pytest.mark.parametrize(
546-
"fn, trans, config",
546+
"fn, trans, kwargs",
547547
[
548548
(F.invert, transforms.RandomInvert, {}),
549549
(F.posterize, transforms.RandomPosterize, {"bits": 4}),
550550
(F.solarize, transforms.RandomSolarize, {"threshold": 192}),
551551
(F.adjust_sharpness, transforms.RandomAdjustSharpness, {"sharpness_factor": 2.0}),
552552
(F.autocontrast, transforms.RandomAutocontrast, {}),
553553
(F.equalize, transforms.RandomEqualize, {}),
554+
(F.vflip, transforms.RandomVerticalFlip, {}),
555+
(F.hflip, transforms.RandomHorizontalFlip, {}),
556+
(partial(F.to_grayscale, num_output_channels=3), transforms.RandomGrayscale, {}),
554557
],
555558
)
556-
@pytest.mark.parametrize("p", (0.5, 0.7))
557-
def test_randomness(fn, trans, config, p):
558-
random_state = random.getstate()
559-
random.seed(42)
559+
@pytest.mark.parametrize("seed", range(10))
560+
@pytest.mark.parametrize("p", (0, 1))
561+
def test_randomness(fn, trans, kwargs, seed, p):
562+
torch.manual_seed(seed)
560563
img = transforms.ToPILImage()(torch.rand(3, 16, 18))
561564

562-
inv_img = fn(img, **config)
565+
expected_transformed_img = fn(img, **kwargs)
566+
randomly_transformed_img = trans(p=p, **kwargs)(img)
563567

564-
num_samples = 250
565-
counts = 0
566-
for _ in range(num_samples):
567-
tranformation = trans(p=p, **config)
568-
tranformation.__repr__()
569-
out = tranformation(img)
570-
if out == inv_img:
571-
counts += 1
568+
if p == 0:
569+
assert randomly_transformed_img == img
570+
elif p == 1:
571+
assert randomly_transformed_img == expected_transformed_img
572572

573-
p_value = stats.binom_test(counts, num_samples, p=p)
574-
random.setstate(random_state)
575-
assert p_value > 0.0001
573+
trans(**kwargs).__repr__()
576574

577575

578576
class TestToPil:
@@ -1362,160 +1360,42 @@ def test_to_grayscale():
13621360
trans4.__repr__()
13631361

13641362

1365-
@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
1366-
def test_random_grayscale():
1367-
"""Unit tests for random grayscale transform"""
1363+
@pytest.mark.parametrize("seed", range(10))
1364+
@pytest.mark.parametrize("p", (0, 1))
1365+
def test_random_apply(p, seed):
1366+
torch.manual_seed(seed)
1367+
random_apply_transform = transforms.RandomApply([transforms.RandomRotation((1, 45))], p=p)
1368+
img = transforms.ToPILImage()(torch.rand(3, 30, 40))
1369+
out = random_apply_transform(img)
1370+
if p == 0:
1371+
assert out == img
1372+
elif p == 1:
1373+
assert out != img
13681374

1369-
# Test Set 1: RGB -> 3 channel grayscale
1370-
np_rng = np.random.RandomState(0)
1371-
random_state = random.getstate()
1372-
random.seed(42)
1373-
x_shape = [2, 2, 3]
1374-
x_np = np_rng.randint(0, 256, x_shape, np.uint8)
1375-
x_pil = Image.fromarray(x_np, mode="RGB")
1376-
x_pil_2 = x_pil.convert("L")
1377-
gray_np = np.array(x_pil_2)
1378-
1379-
num_samples = 250
1380-
num_gray = 0
1381-
for _ in range(num_samples):
1382-
gray_pil_2 = transforms.RandomGrayscale(p=0.5)(x_pil)
1383-
gray_np_2 = np.array(gray_pil_2)
1384-
if (
1385-
np.array_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1])
1386-
and np.array_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2])
1387-
and np.array_equal(gray_np, gray_np_2[:, :, 0])
1388-
):
1389-
num_gray = num_gray + 1
1390-
1391-
p_value = stats.binom_test(num_gray, num_samples, p=0.5)
1392-
random.setstate(random_state)
1393-
assert p_value > 0.0001
1394-
1395-
# Test Set 2: grayscale -> 1 channel grayscale
1396-
random_state = random.getstate()
1397-
random.seed(42)
1398-
x_shape = [2, 2, 3]
1399-
x_np = np_rng.randint(0, 256, x_shape, np.uint8)
1400-
x_pil = Image.fromarray(x_np, mode="RGB")
1401-
x_pil_2 = x_pil.convert("L")
1402-
gray_np = np.array(x_pil_2)
1403-
1404-
num_samples = 250
1405-
num_gray = 0
1406-
for _ in range(num_samples):
1407-
gray_pil_3 = transforms.RandomGrayscale(p=0.5)(x_pil_2)
1408-
gray_np_3 = np.array(gray_pil_3)
1409-
if np.array_equal(gray_np, gray_np_3):
1410-
num_gray = num_gray + 1
1411-
1412-
p_value = stats.binom_test(num_gray, num_samples, p=1.0) # Note: grayscale is always unchanged
1413-
random.setstate(random_state)
1414-
assert p_value > 0.0001
1415-
1416-
# Test set 3: Explicit tests
1417-
x_shape = [2, 2, 3]
1418-
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
1419-
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
1420-
x_pil = Image.fromarray(x_np, mode="RGB")
1421-
x_pil_2 = x_pil.convert("L")
1422-
gray_np = np.array(x_pil_2)
1423-
1424-
# Case 3a: RGB -> 3 channel grayscale (grayscaled)
1425-
trans2 = transforms.RandomGrayscale(p=1.0)
1426-
gray_pil_2 = trans2(x_pil)
1427-
gray_np_2 = np.array(gray_pil_2)
1428-
assert gray_pil_2.mode == "RGB", "mode should be RGB"
1429-
assert gray_np_2.shape == tuple(x_shape), "should be 3 channel"
1430-
assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1])
1431-
assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2])
1432-
assert_equal(gray_np, gray_np_2[:, :, 0])
1433-
1434-
# Case 3b: RGB -> 3 channel grayscale (unchanged)
1435-
trans2 = transforms.RandomGrayscale(p=0.0)
1436-
gray_pil_2 = trans2(x_pil)
1437-
gray_np_2 = np.array(gray_pil_2)
1438-
assert gray_pil_2.mode == "RGB", "mode should be RGB"
1439-
assert gray_np_2.shape == tuple(x_shape), "should be 3 channel"
1440-
assert_equal(x_np, gray_np_2)
1441-
1442-
# Case 3c: 1 channel grayscale -> 1 channel grayscale (grayscaled)
1443-
trans3 = transforms.RandomGrayscale(p=1.0)
1444-
gray_pil_3 = trans3(x_pil_2)
1445-
gray_np_3 = np.array(gray_pil_3)
1446-
assert gray_pil_3.mode == "L", "mode should be L"
1447-
assert gray_np_3.shape == tuple(x_shape[0:2]), "should be 1 channel"
1448-
assert_equal(gray_np, gray_np_3)
1449-
1450-
# Case 3d: 1 channel grayscale -> 1 channel grayscale (unchanged)
1451-
trans3 = transforms.RandomGrayscale(p=0.0)
1452-
gray_pil_3 = trans3(x_pil_2)
1453-
gray_np_3 = np.array(gray_pil_3)
1454-
assert gray_pil_3.mode == "L", "mode should be L"
1455-
assert gray_np_3.shape == tuple(x_shape[0:2]), "should be 1 channel"
1456-
assert_equal(gray_np, gray_np_3)
1375+
# Checking if RandomApply can be printed as string
1376+
random_apply_transform.__repr__()
14571377

1458-
# Checking if RandomGrayscale can be printed as string
1459-
trans3.__repr__()
14601378

1379+
@pytest.mark.parametrize("seed", range(10))
1380+
@pytest.mark.parametrize("proba_passthrough", (0, 1))
1381+
def test_random_choice(proba_passthrough, seed):
1382+
random.seed(seed) # RandomChoice relies on python builtin random.choice, not pytorch
14611383

1462-
@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
1463-
def test_random_apply():
1464-
random_state = random.getstate()
1465-
random.seed(42)
1466-
random_apply_transform = transforms.RandomApply(
1384+
random_choice_transform = transforms.RandomChoice(
14671385
[
1468-
transforms.RandomRotation((-45, 45)),
1469-
transforms.RandomHorizontalFlip(),
1470-
transforms.RandomVerticalFlip(),
1386+
lambda x: x, # passthrough
1387+
transforms.RandomRotation((1, 45)),
14711388
],
1472-
p=0.75,
1389+
p=[proba_passthrough, 1 - proba_passthrough],
14731390
)
1474-
img = transforms.ToPILImage()(torch.rand(3, 10, 10))
1475-
num_samples = 250
1476-
num_applies = 0
1477-
for _ in range(num_samples):
1478-
out = random_apply_transform(img)
1479-
if out != img:
1480-
num_applies += 1
1481-
1482-
p_value = stats.binom_test(num_applies, num_samples, p=0.75)
1483-
random.setstate(random_state)
1484-
assert p_value > 0.0001
1485-
1486-
# Checking if RandomApply can be printed as string
1487-
random_apply_transform.__repr__()
1488-
14891391

1490-
@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
1491-
def test_random_choice():
1492-
random_state = random.getstate()
1493-
random.seed(42)
1494-
random_choice_transform = transforms.RandomChoice(
1495-
[transforms.Resize(15), transforms.Resize(20), transforms.CenterCrop(10)], [1 / 3, 1 / 3, 1 / 3]
1496-
)
1497-
img = transforms.ToPILImage()(torch.rand(3, 25, 25))
1498-
num_samples = 250
1499-
num_resize_15 = 0
1500-
num_resize_20 = 0
1501-
num_crop_10 = 0
1502-
for _ in range(num_samples):
1503-
out = random_choice_transform(img)
1504-
if out.size == (15, 15):
1505-
num_resize_15 += 1
1506-
elif out.size == (20, 20):
1507-
num_resize_20 += 1
1508-
elif out.size == (10, 10):
1509-
num_crop_10 += 1
1510-
1511-
p_value = stats.binom_test(num_resize_15, num_samples, p=0.33333)
1512-
assert p_value > 0.0001
1513-
p_value = stats.binom_test(num_resize_20, num_samples, p=0.33333)
1514-
assert p_value > 0.0001
1515-
p_value = stats.binom_test(num_crop_10, num_samples, p=0.33333)
1516-
assert p_value > 0.0001
1392+
img = transforms.ToPILImage()(torch.rand(3, 30, 40))
1393+
out = random_choice_transform(img)
1394+
if proba_passthrough == 1:
1395+
assert out == img
1396+
elif proba_passthrough == 0:
1397+
assert out != img
15171398

1518-
random.setstate(random_state)
15191399
# Checking if RandomChoice can be printed as string
15201400
random_choice_transform.__repr__()
15211401

@@ -1888,6 +1768,7 @@ def test_random_erasing():
18881768
tol = 0.05
18891769
assert 1 / 3 - tol <= aspect_ratio <= 3 + tol
18901770

1771+
# Make sure that h > w and h < w are equaly likely (log-scale sampling)
18911772
aspect_ratios = []
18921773
random.seed(42)
18931774
trial = 1000
@@ -2011,72 +1892,6 @@ def test_randomperspective_fill(mode):
20111892
F.perspective(img_conv, startpoints, endpoints, fill=tuple([fill] * wrong_num_bands))
20121893

20131894

2014-
@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
2015-
def test_random_vertical_flip():
2016-
random_state = random.getstate()
2017-
random.seed(42)
2018-
img = transforms.ToPILImage()(torch.rand(3, 10, 10))
2019-
vimg = img.transpose(Image.FLIP_TOP_BOTTOM)
2020-
2021-
num_samples = 250
2022-
num_vertical = 0
2023-
for _ in range(num_samples):
2024-
out = transforms.RandomVerticalFlip()(img)
2025-
if out == vimg:
2026-
num_vertical += 1
2027-
2028-
p_value = stats.binom_test(num_vertical, num_samples, p=0.5)
2029-
random.setstate(random_state)
2030-
assert p_value > 0.0001
2031-
2032-
num_samples = 250
2033-
num_vertical = 0
2034-
for _ in range(num_samples):
2035-
out = transforms.RandomVerticalFlip(p=0.7)(img)
2036-
if out == vimg:
2037-
num_vertical += 1
2038-
2039-
p_value = stats.binom_test(num_vertical, num_samples, p=0.7)
2040-
random.setstate(random_state)
2041-
assert p_value > 0.0001
2042-
2043-
# Checking if RandomVerticalFlip can be printed as string
2044-
transforms.RandomVerticalFlip().__repr__()
2045-
2046-
2047-
@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
2048-
def test_random_horizontal_flip():
2049-
random_state = random.getstate()
2050-
random.seed(42)
2051-
img = transforms.ToPILImage()(torch.rand(3, 10, 10))
2052-
himg = img.transpose(Image.FLIP_LEFT_RIGHT)
2053-
2054-
num_samples = 250
2055-
num_horizontal = 0
2056-
for _ in range(num_samples):
2057-
out = transforms.RandomHorizontalFlip()(img)
2058-
if out == himg:
2059-
num_horizontal += 1
2060-
2061-
p_value = stats.binom_test(num_horizontal, num_samples, p=0.5)
2062-
random.setstate(random_state)
2063-
assert p_value > 0.0001
2064-
2065-
num_samples = 250
2066-
num_horizontal = 0
2067-
for _ in range(num_samples):
2068-
out = transforms.RandomHorizontalFlip(p=0.7)(img)
2069-
if out == himg:
2070-
num_horizontal += 1
2071-
2072-
p_value = stats.binom_test(num_horizontal, num_samples, p=0.7)
2073-
random.setstate(random_state)
2074-
assert p_value > 0.0001
2075-
2076-
# Checking if RandomHorizontalFlip can be printed as string
2077-
transforms.RandomHorizontalFlip().__repr__()
2078-
2079-
20801895
@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
20811896
def test_normalize():
20821897
def samples_from_standard_normal(tensor):

torchvision/transforms/transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,7 @@ class RandomChoice(RandomTransforms):
562562
def __init__(self, transforms, p=None):
563563
super().__init__(transforms)
564564
if p is not None and not isinstance(p, Sequence):
565-
raise TypeError("Argument transforms should be a sequence")
565+
raise TypeError("Argument p should be a sequence")
566566
self.p = p
567567

568568
def __call__(self, *args):

0 commit comments

Comments
 (0)