Skip to content

Commit 8e60cf4

Browse files
Philip Meierfmassa
Philip Meier
authored andcommitted
Test for SVHN (#1086)
* cast images to PIL at instantiation instead of runtime * add test for svhn * added tests for remaining SVHN splits * flake8 * rolled back changes to datasets
1 parent 0c75d99 commit 8e60cf4

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

test/fakedata_generation.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,3 +241,20 @@ def _make_polygon_target(file):
241241
'{city}_000000_000000_leftImg8bit.png'.format(city=city)))
242242

243243
yield tmp_dir
244+
245+
246+
@contextlib.contextmanager
247+
def svhn_root():
248+
import scipy.io as sio
249+
250+
def _make_mat(file):
251+
images = np.zeros((32, 32, 3, 2), dtype=np.uint8)
252+
targets = np.zeros((2,), dtype=np.uint8)
253+
sio.savemat(file, {'X': images, 'y': targets})
254+
255+
with get_tmp_dir() as root:
256+
_make_mat(os.path.join(root, "train_32x32.mat"))
257+
_make_mat(os.path.join(root, "test_32x32.mat"))
258+
_make_mat(os.path.join(root, "extra_32x32.mat"))
259+
260+
yield root

test/test_datasets.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from torch._utils_internal import get_file_path_2
88
import torchvision
99
from common_utils import get_tmp_dir
10-
from fakedata_generation import mnist_root, cifar_root, imagenet_root, cityscapes_root
10+
from fakedata_generation import mnist_root, cifar_root, imagenet_root, \
11+
cityscapes_root, svhn_root
1112

1213

1314
class Tester(unittest.TestCase):
@@ -185,6 +186,19 @@ def test_cityscapes(self):
185186
self.assertTrue(isinstance(output[1][1], dict)) # polygon
186187
self.assertTrue(isinstance(output[1][2], PIL.Image.Image)) # color
187188

189+
@mock.patch('torchvision.datasets.SVHN._check_integrity')
190+
def test_svhn(self, mock_check):
191+
mock_check.return_value = True
192+
with svhn_root() as root:
193+
dataset = torchvision.datasets.SVHN(root, split="train")
194+
self.generic_classification_dataset_test(dataset, num_images=2)
195+
196+
dataset = torchvision.datasets.SVHN(root, split="test")
197+
self.generic_classification_dataset_test(dataset, num_images=2)
198+
199+
dataset = torchvision.datasets.SVHN(root, split="extra")
200+
self.generic_classification_dataset_test(dataset, num_images=2)
201+
188202

189203
if __name__ == '__main__':
190204
unittest.main()

0 commit comments

Comments
 (0)