Skip to content

Commit aade017

Browse files
datumbox, vfdev-5, and fmassa
authored and committed
Implement all AutoAugment transforms + Policies (#3123)
Summary: * Invert Transform (#3104) * Adding invert operator. * Make use of the _assert_channels(). * Update upper bound value. * Remove private doc from invert, create or reuse generic testing methods to avoid duplication of code in the tests. (#3106) * Create posterize transformation and refactor common methods to assist reuse. (#3108) * Implement the solarize transform. (#3112) * Implement the adjust_sharpness transform (#3114) * Adding functional operator for sharpness. * Adding transforms for sharpness. * Handling tiny images and adding a test. * Implement the autocontrast transform. (#3117) * Implement the equalize transform (#3119) * Implement the equalize transform. * Turn off deterministic for histogram. * Fixing test. (#3126) * Force ratio to be float to avoid numeric overflows on blend. (#3127) * Separate the tests of Adjust Sharpness from ColorJitter. (#3128) * Add AutoAugment Policies and main Transform (#3142) * Separate the tests of Adjust Sharpness from ColorJitter. * Initial implementation, not-jitable. * AutoAugment passing JIT. * Adding tests/docs, changing formatting. * Update test. * Fix formats * Fix documentation and imports. * Apply changes from code review: - Move the transformations outside of AutoAugment on a separate method. - Renamed degenerate method for sharpness for better clarity. * Update torchvision/transforms/functional.py * Apply more changes from code review: - Add InterpolationMode parameter. - Move all declarations away from AutoAugment constructor and into the private method. * Update documentation. * Apply suggestions from code review * Apply changes from code review: - Refactor code to eliminate as any to() and clamp() as possible. - Reuse methods where possible. - Apply speed ups. * Replacing pad. 
Reviewed By: fmassa Differential Revision: D25679210 fbshipit-source-id: f7b4a086dc9479e44f93e508d6070280cbc9bdac Co-authored-by: vfdev <[email protected]> Co-authored-by: Francisco Massa <[email protected]> Co-authored-by: vfdev <[email protected]> Co-authored-by: Francisco Massa <[email protected]>
1 parent 637ea80 commit aade017

9 files changed

+968
-4
lines changed

test/test_functional_tensor.py

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,13 +289,14 @@ def test_pad(self):
289289

290290
self._test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **kwargs)
291291

292-
def _test_adjust_fn(self, fn, fn_pil, fn_t, configs, tol=2.0 + 1e-10, agg_method="max"):
292+
def _test_adjust_fn(self, fn, fn_pil, fn_t, configs, tol=2.0 + 1e-10, agg_method="max",
293+
dts=(None, torch.float32, torch.float64)):
293294
script_fn = torch.jit.script(fn)
294295
torch.manual_seed(15)
295296
tensor, pil_img = self._create_data(26, 34, device=self.device)
296297
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
297298

298-
for dt in [None, torch.float32, torch.float64]:
299+
for dt in dts:
299300

300301
if dt is not None:
301302
tensor = F.convert_image_dtype(tensor, dt)
@@ -862,6 +863,77 @@ def test_gaussian_blur(self):
862863
msg="{}, {}".format(ksize, sigma)
863864
)
864865

866+
def test_invert(self):
867+
self._test_adjust_fn(
868+
F.invert,
869+
F_pil.invert,
870+
F_t.invert,
871+
[{}],
872+
tol=1.0,
873+
agg_method="max"
874+
)
875+
876+
def test_posterize(self):
877+
self._test_adjust_fn(
878+
F.posterize,
879+
F_pil.posterize,
880+
F_t.posterize,
881+
[{"bits": bits} for bits in range(0, 8)],
882+
tol=1.0,
883+
agg_method="max",
884+
dts=(None,)
885+
)
886+
887+
def test_solarize(self):
888+
self._test_adjust_fn(
889+
F.solarize,
890+
F_pil.solarize,
891+
F_t.solarize,
892+
[{"threshold": threshold} for threshold in [0, 64, 128, 192, 255]],
893+
tol=1.0,
894+
agg_method="max",
895+
dts=(None,)
896+
)
897+
self._test_adjust_fn(
898+
F.solarize,
899+
lambda img, threshold: F_pil.solarize(img, 255 * threshold),
900+
F_t.solarize,
901+
[{"threshold": threshold} for threshold in [0.0, 0.25, 0.5, 0.75, 1.0]],
902+
tol=1.0,
903+
agg_method="max",
904+
dts=(torch.float32, torch.float64)
905+
)
906+
907+
def test_adjust_sharpness(self):
908+
self._test_adjust_fn(
909+
F.adjust_sharpness,
910+
F_pil.adjust_sharpness,
911+
F_t.adjust_sharpness,
912+
[{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]
913+
)
914+
915+
def test_autocontrast(self):
916+
self._test_adjust_fn(
917+
F.autocontrast,
918+
F_pil.autocontrast,
919+
F_t.autocontrast,
920+
[{}],
921+
tol=1.0,
922+
agg_method="max"
923+
)
924+
925+
def test_equalize(self):
926+
torch.set_deterministic(False)
927+
self._test_adjust_fn(
928+
F.equalize,
929+
F_pil.equalize,
930+
F_t.equalize,
931+
[{}],
932+
tol=1.0,
933+
agg_method="max",
934+
dts=(None,)
935+
)
936+
865937

866938
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
867939
class CUDATester(Tester):

test/test_transforms.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1234,6 +1234,48 @@ def test_adjust_hue(self):
12341234
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
12351235
self.assertTrue(np.allclose(y_np, y_ans))
12361236

1237+
def test_adjust_sharpness(self):
1238+
x_shape = [4, 4, 3]
1239+
x_data = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 0,
1240+
0, 65, 108, 101, 120, 97, 110, 100, 101, 114, 32, 86, 114, 121, 110, 105,
1241+
111, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117]
1242+
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
1243+
x_pil = Image.fromarray(x_np, mode='RGB')
1244+
1245+
# test 0
1246+
y_pil = F.adjust_sharpness(x_pil, 1)
1247+
y_np = np.array(y_pil)
1248+
self.assertTrue(np.allclose(y_np, x_np))
1249+
1250+
# test 1
1251+
y_pil = F.adjust_sharpness(x_pil, 0.5)
1252+
y_np = np.array(y_pil)
1253+
y_ans = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 30,
1254+
30, 74, 103, 96, 114, 97, 110, 100, 101, 114, 32, 81, 103, 108, 102, 101,
1255+
107, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117]
1256+
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
1257+
self.assertTrue(np.allclose(y_np, y_ans))
1258+
1259+
# test 2
1260+
y_pil = F.adjust_sharpness(x_pil, 2)
1261+
y_np = np.array(y_pil)
1262+
y_ans = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 0,
1263+
0, 46, 118, 111, 132, 97, 110, 100, 101, 114, 32, 95, 135, 146, 126, 112,
1264+
119, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117]
1265+
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
1266+
self.assertTrue(np.allclose(y_np, y_ans))
1267+
1268+
# test 3
1269+
x_shape = [2, 2, 3]
1270+
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
1271+
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
1272+
x_pil = Image.fromarray(x_np, mode='RGB')
1273+
x_th = torch.tensor(x_np.transpose(2, 0, 1))
1274+
y_pil = F.adjust_sharpness(x_pil, 2)
1275+
y_np = np.array(y_pil).transpose(2, 0, 1)
1276+
y_th = F.adjust_sharpness(x_th, 2)
1277+
self.assertTrue(np.allclose(y_np, y_th.numpy()))
1278+
12371279
def test_adjust_gamma(self):
12381280
x_shape = [2, 2, 3]
12391281
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
@@ -1270,6 +1312,7 @@ def test_adjusts_L_mode(self):
12701312
self.assertEqual(F.adjust_saturation(x_l, 2).mode, 'L')
12711313
self.assertEqual(F.adjust_contrast(x_l, 2).mode, 'L')
12721314
self.assertEqual(F.adjust_hue(x_l, 0.4).mode, 'L')
1315+
self.assertEqual(F.adjust_sharpness(x_l, 2).mode, 'L')
12731316
self.assertEqual(F.adjust_gamma(x_l, 0.5).mode, 'L')
12741317

12751318
def test_color_jitter(self):
@@ -1751,6 +1794,86 @@ def test_gaussian_blur_asserts(self):
17511794
with self.assertRaisesRegex(ValueError, r"sigma should be a single number or a list/tuple with length 2"):
17521795
transforms.GaussianBlur(3, "sigma_string")
17531796

1797+
def _test_randomness(self, fn, trans, configs):
1798+
random_state = random.getstate()
1799+
random.seed(42)
1800+
img = transforms.ToPILImage()(torch.rand(3, 16, 18))
1801+
1802+
for p in [0.5, 0.7]:
1803+
for config in configs:
1804+
inv_img = fn(img, **config)
1805+
1806+
num_samples = 250
1807+
counts = 0
1808+
for _ in range(num_samples):
1809+
tranformation = trans(p=p, **config)
1810+
tranformation.__repr__()
1811+
out = tranformation(img)
1812+
if out == inv_img:
1813+
counts += 1
1814+
1815+
p_value = stats.binom_test(counts, num_samples, p=p)
1816+
random.setstate(random_state)
1817+
self.assertGreater(p_value, 0.0001)
1818+
1819+
@unittest.skipIf(stats is None, 'scipy.stats not available')
1820+
def test_random_invert(self):
1821+
self._test_randomness(
1822+
F.invert,
1823+
transforms.RandomInvert,
1824+
[{}]
1825+
)
1826+
1827+
@unittest.skipIf(stats is None, 'scipy.stats not available')
1828+
def test_random_posterize(self):
1829+
self._test_randomness(
1830+
F.posterize,
1831+
transforms.RandomPosterize,
1832+
[{"bits": 4}]
1833+
)
1834+
1835+
@unittest.skipIf(stats is None, 'scipy.stats not available')
1836+
def test_random_solarize(self):
1837+
self._test_randomness(
1838+
F.solarize,
1839+
transforms.RandomSolarize,
1840+
[{"threshold": 192}]
1841+
)
1842+
1843+
@unittest.skipIf(stats is None, 'scipy.stats not available')
1844+
def test_random_adjust_sharpness(self):
1845+
self._test_randomness(
1846+
F.adjust_sharpness,
1847+
transforms.RandomAdjustSharpness,
1848+
[{"sharpness_factor": 2.0}]
1849+
)
1850+
1851+
@unittest.skipIf(stats is None, 'scipy.stats not available')
1852+
def test_random_autocontrast(self):
1853+
self._test_randomness(
1854+
F.autocontrast,
1855+
transforms.RandomAutocontrast,
1856+
[{}]
1857+
)
1858+
1859+
@unittest.skipIf(stats is None, 'scipy.stats not available')
1860+
def test_random_equalize(self):
1861+
self._test_randomness(
1862+
F.equalize,
1863+
transforms.RandomEqualize,
1864+
[{}]
1865+
)
1866+
1867+
def test_autoaugment(self):
1868+
for policy in transforms.AutoAugmentPolicy:
1869+
for fill in [None, 85, (128, 128, 128)]:
1870+
random.seed(42)
1871+
img = Image.open(GRACE_HOPPER)
1872+
transform = transforms.AutoAugment(policy=policy, fill=fill)
1873+
for _ in range(100):
1874+
img = transform(img)
1875+
transform.__repr__()
1876+
17541877

17551878
if __name__ == '__main__':
17561879
unittest.main()

test/test_transforms_tensor.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,34 @@ def test_random_horizontal_flip(self):
8989
def test_random_vertical_flip(self):
9090
self._test_op('vflip', 'RandomVerticalFlip')
9191

92+
def test_random_invert(self):
93+
self._test_op('invert', 'RandomInvert')
94+
95+
def test_random_posterize(self):
96+
fn_kwargs = meth_kwargs = {"bits": 4}
97+
self._test_op(
98+
'posterize', 'RandomPosterize', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
99+
)
100+
101+
def test_random_solarize(self):
102+
fn_kwargs = meth_kwargs = {"threshold": 192.0}
103+
self._test_op(
104+
'solarize', 'RandomSolarize', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
105+
)
106+
107+
def test_random_adjust_sharpness(self):
108+
fn_kwargs = meth_kwargs = {"sharpness_factor": 2.0}
109+
self._test_op(
110+
'adjust_sharpness', 'RandomAdjustSharpness', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
111+
)
112+
113+
def test_random_autocontrast(self):
114+
self._test_op('autocontrast', 'RandomAutocontrast')
115+
116+
def test_random_equalize(self):
117+
torch.set_deterministic(False)
118+
self._test_op('equalize', 'RandomEqualize')
119+
92120
def test_color_jitter(self):
93121

94122
tol = 1.0 + 1e-10
@@ -598,6 +626,22 @@ def test_convert_image_dtype(self):
598626
with get_tmp_dir() as tmp_dir:
599627
scripted_fn.save(os.path.join(tmp_dir, "t_convert_dtype.pt"))
600628

629+
def test_autoaugment(self):
630+
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
631+
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
632+
633+
for policy in T.AutoAugmentPolicy:
634+
for fill in [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]:
635+
for _ in range(100):
636+
transform = T.AutoAugment(policy=policy, fill=fill)
637+
s_transform = torch.jit.script(transform)
638+
639+
self._test_transform_vs_scripted(transform, s_transform, tensor)
640+
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
641+
642+
with get_tmp_dir() as tmp_dir:
643+
s_transform.save(os.path.join(tmp_dir, "t_autoaugment.pt"))
644+
601645

602646
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
603647
class CUDATester(Tester):

torchvision/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .transforms import *
2+
from .autoaugment import *

0 commit comments

Comments (0)