-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Add tests for the STL10 dataset #3345
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
e8a08df
extract some functionality from places365 fakedata for common use
pmeier aac2169
add a common DatasetTestcase
pmeier 1c2ce43
add fakedata generation and tests for STL10
pmeier 59041ea
lint
pmeier 4c6fede
Merge branch 'master' into stl10-tests
fmassa File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
import contextlib | ||
import sys | ||
import os | ||
import unittest | ||
|
@@ -7,9 +8,10 @@ | |
from PIL import Image | ||
from torch._utils_internal import get_file_path_2 | ||
import torchvision | ||
from torchvision.datasets import utils | ||
from common_utils import get_tmp_dir | ||
from fakedata_generation import mnist_root, cifar_root, imagenet_root, \ | ||
cityscapes_root, svhn_root, voc_root, ucf101_root, places365_root, widerface_root | ||
cityscapes_root, svhn_root, voc_root, ucf101_root, places365_root, widerface_root, stl10_root | ||
import xml.etree.ElementTree as ET | ||
from urllib.request import Request, urlopen | ||
import itertools | ||
|
@@ -28,7 +30,7 @@ | |
HAS_PYAV = False | ||
|
||
|
||
class Tester(unittest.TestCase): | ||
class DatasetTestcase(unittest.TestCase): | ||
def generic_classification_dataset_test(self, dataset, num_images=1): | ||
self.assertEqual(len(dataset), num_images) | ||
img, target = dataset[0] | ||
|
@@ -41,6 +43,8 @@ def generic_segmentation_dataset_test(self, dataset, num_images=1): | |
self.assertTrue(isinstance(img, PIL.Image.Image)) | ||
self.assertTrue(isinstance(target, PIL.Image.Image)) | ||
|
||
|
||
class Tester(DatasetTestcase): | ||
def test_imagefolder(self): | ||
# TODO: create the fake data on-the-fly | ||
FAKEDATA_DIR = get_file_path_2( | ||
|
@@ -354,7 +358,7 @@ def test_places365_devkit_download(self): | |
def test_places365_devkit_no_download(self): | ||
for split in ("train-standard", "train-challenge", "val"): | ||
with self.subTest(split=split): | ||
with places365_root(split=split, extract_images=False) as places365: | ||
with places365_root(split=split) as places365: | ||
root, data = places365 | ||
|
||
with self.assertRaises(RuntimeError): | ||
|
@@ -383,12 +387,84 @@ def test_places365_images_download_preexisting(self): | |
torchvision.datasets.Places365(root, split=split, small=small, download=True) | ||
|
||
def test_places365_repr_smoke(self): | ||
with places365_root(extract_images=False) as places365: | ||
with places365_root() as places365: | ||
root, data = places365 | ||
|
||
dataset = torchvision.datasets.Places365(root, download=True) | ||
self.assertIsInstance(repr(dataset), str) | ||
|
||
|
||
class STL10Tester(DatasetTestcase): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After adding the tests for |
||
@contextlib.contextmanager | ||
def mocked_root(self): | ||
with stl10_root() as (root, data): | ||
yield root, data | ||
|
||
@contextlib.contextmanager | ||
def mocked_dataset(self, pre_extract=False, download=True, **kwargs): | ||
with self.mocked_root() as (root, data): | ||
if pre_extract: | ||
utils.extract_archive(os.path.join(root, data["archive"])) | ||
dataset = torchvision.datasets.STL10(root, download=download, **kwargs) | ||
yield dataset, data | ||
|
||
def test_not_found(self): | ||
with self.assertRaises(RuntimeError): | ||
with self.mocked_dataset(download=False): | ||
pass | ||
|
||
def test_splits(self): | ||
for split in ('train', 'train+unlabeled', 'unlabeled', 'test'): | ||
with self.mocked_dataset(split=split) as (dataset, data): | ||
num_images = sum([data["num_images_in_split"][part] for part in split.split("+")]) | ||
self.generic_classification_dataset_test(dataset, num_images=num_images) | ||
|
||
def test_folds(self): | ||
for fold in range(10): | ||
with self.mocked_dataset(split="train", folds=fold) as (dataset, data): | ||
num_images = data["num_images_in_folds"][fold] | ||
self.assertEqual(len(dataset), num_images) | ||
|
||
def test_invalid_folds1(self): | ||
with self.assertRaises(ValueError): | ||
with self.mocked_dataset(folds=10): | ||
pass | ||
|
||
def test_invalid_folds2(self): | ||
with self.assertRaises(ValueError): | ||
with self.mocked_dataset(folds="0"): | ||
pass | ||
|
||
def test_transforms(self): | ||
expected_image = "image" | ||
expected_target = "target" | ||
|
||
def transform(image): | ||
return expected_image | ||
|
||
def target_transform(target): | ||
return expected_target | ||
|
||
with self.mocked_dataset(transform=transform, target_transform=target_transform) as (dataset, _): | ||
actual_image, actual_target = dataset[0] | ||
|
||
self.assertEqual(actual_image, expected_image) | ||
self.assertEqual(actual_target, expected_target) | ||
|
||
def test_unlabeled(self): | ||
with self.mocked_dataset(split="unlabeled") as (dataset, _): | ||
labels = [dataset[idx][1] for idx in range(len(dataset))] | ||
self.assertTrue(all([label == -1 for label in labels])) | ||
|
||
@unittest.mock.patch("torchvision.datasets.stl10.download_and_extract_archive") | ||
def test_download_preexisting(self, mock): | ||
with self.mocked_dataset(pre_extract=True) as (dataset, data): | ||
mock.assert_not_called() | ||
|
||
def test_repr_smoke(self): | ||
with self.mocked_dataset() as (dataset, _): | ||
self.assertIsInstance(repr(dataset), str) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was necessary since
STL10Tester
needs access togeneric_classification_dataset_test
without subclassingTester
directly. Otherwise runningSTL10Tester
would also run all inherited test.