|
7 | 7 | from torch._utils_internal import get_file_path_2
|
8 | 8 | import torchvision
|
9 | 9 | from common_utils import get_tmp_dir
|
10 |
| -from fakedata_generation import mnist_root, cifar_root, imagenet_root, cityscapes_root |
| 10 | +from fakedata_generation import mnist_root, cifar_root, imagenet_root, \ |
| 11 | + cityscapes_root, svhn_root |
11 | 12 |
|
12 | 13 |
|
13 | 14 | class Tester(unittest.TestCase):
|
@@ -185,6 +186,19 @@ def test_cityscapes(self):
|
185 | 186 | self.assertTrue(isinstance(output[1][1], dict)) # polygon
|
186 | 187 | self.assertTrue(isinstance(output[1][2], PIL.Image.Image)) # color
|
187 | 188 |
|
| 189 | + @mock.patch('torchvision.datasets.SVHN._check_integrity') |
| 190 | + def test_svhn(self, mock_check): |
| 191 | + mock_check.return_value = True |
| 192 | + with svhn_root() as root: |
| 193 | + dataset = torchvision.datasets.SVHN(root, split="train") |
| 194 | + self.generic_classification_dataset_test(dataset, num_images=2) |
| 195 | + |
| 196 | + dataset = torchvision.datasets.SVHN(root, split="test") |
| 197 | + self.generic_classification_dataset_test(dataset, num_images=2) |
| 198 | + |
| 199 | + dataset = torchvision.datasets.SVHN(root, split="extra") |
| 200 | + self.generic_classification_dataset_test(dataset, num_images=2) |
| 201 | + |
188 | 202 |
|
189 | 203 | if __name__ == '__main__':
|
190 | 204 | unittest.main()
|
0 commit comments