Skip to content

Commit 3d39a62

Browse files
author
Philip Meier
committed
add test for svhn
1 parent 3f84497 commit 3d39a62

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

test/fakedata_generation.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,3 +168,20 @@ def _make_devkit_archive(root):
168168
_make_devkit_archive(root)
169169

170170
yield root
171+
172+
173+
@contextlib.contextmanager
174+
def svhn_root():
175+
import scipy.io as sio
176+
177+
def _make_mat(file):
178+
images = np.zeros((32, 32, 3, 2), dtype=np.uint8)
179+
targets = np.zeros((2,), dtype=np.uint8)
180+
sio.savemat(file, {'X': images, 'y': targets})
181+
182+
with get_tmp_dir() as root:
183+
_make_mat(os.path.join(root, "train_32x32.mat"))
184+
_make_mat(os.path.join(root, "test_32x32.mat"))
185+
_make_mat(os.path.join(root, "extra_32x32.mat"))
186+
187+
yield root

test/test_datasets.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from torch._utils_internal import get_file_path_2
77
import torchvision
88
from common_utils import get_tmp_dir
9-
from fakedata_generation import mnist_root, cifar_root, imagenet_root
9+
from fakedata_generation import mnist_root, cifar_root, imagenet_root, svhn_root
1010

1111

1212
class Tester(unittest.TestCase):
@@ -133,6 +133,13 @@ def test_cifar100(self, mock_ext_check, mock_int_check):
133133
img, target = dataset[0]
134134
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
135135

136+
@mock.patch('torchvision.datasets.SVHN._check_integrity')
137+
def test_svhn(self, mock_check):
138+
mock_check.return_value = True
139+
with svhn_root() as root:
140+
dataset = torchvision.datasets.SVHN(root, split="train")
141+
self.generic_classification_dataset_test(dataset, num_images=2)
142+
136143

137144
if __name__ == '__main__':
138145
unittest.main()

0 commit comments

Comments
 (0)