Skip to content

Commit e32b19e

Browse files
jdsgomesNicolasHug
andauthored
Add support for Rendered sst2 dataset (#5220)
* Adding multiweight support for shufflenetv2 prototype models * Revert "Adding multiweight support for shufflenetv2 prototype models" This reverts commit 31fadbe. * Adding multiweight support for shufflenetv2 prototype models * Revert "Adding multiweight support for shufflenetv2 prototype models" This reverts commit 4e3d900. * Add RenderedSST2 dataset * Address PR comments * Fix bug in dataset verification Co-authored-by: Nicolas Hug <[email protected]>
1 parent f670152 commit e32b19e

File tree

4 files changed

+115
-0
lines changed

4 files changed

+115
-0
lines changed

docs/source/datasets.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
7070
PCAM
7171
PhotoTour
7272
Places365
73+
RenderedSST2
7374
QMNIST
7475
SBDataset
7576
SBU

test/test_datasets.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2665,5 +2665,27 @@ def inject_fake_data(self, tmpdir: str, config):
26652665
return num_images
26662666

26672667

2668+
class RenderedSST2TestCase(datasets_utils.ImageDatasetTestCase):
2669+
DATASET_CLASS = datasets.RenderedSST2
2670+
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val", "test"))
2671+
SPLIT_TO_FOLDER = {"train": "train", "val": "valid", "test": "test"}
2672+
2673+
def inject_fake_data(self, tmpdir: str, config):
2674+
root_folder = pathlib.Path(tmpdir) / "rendered-sst2"
2675+
image_folder = root_folder / self.SPLIT_TO_FOLDER[config["split"]]
2676+
2677+
num_images_per_class = {"train": 5, "test": 6, "val": 7}
2678+
sampled_classes = ["positive", "negative"]
2679+
for cls in sampled_classes:
2680+
datasets_utils.create_image_folder(
2681+
image_folder,
2682+
cls,
2683+
file_name_fn=lambda idx: f"{idx}.png",
2684+
num_examples=num_images_per_class[config["split"]],
2685+
)
2686+
2687+
return len(sampled_classes) * num_images_per_class[config["split"]]
2688+
2689+
26682690
if __name__ == "__main__":
26692691
unittest.main()

torchvision/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from .pcam import PCAM
3030
from .phototour import PhotoTour
3131
from .places365 import Places365
32+
from .rendered_sst2 import RenderedSST2
3233
from .sbd import SBDataset
3334
from .sbu import SBU
3435
from .semeion import SEMEION
@@ -102,4 +103,5 @@
102103
"Country211",
103104
"FGVCAircraft",
104105
"EuroSAT",
106+
"RenderedSST2",
105107
)

torchvision/datasets/rendered_sst2.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
from pathlib import Path
2+
from typing import Any, Tuple, Callable, Optional
3+
4+
import PIL.Image
5+
6+
from .utils import verify_str_arg, download_and_extract_archive
7+
from .vision import VisionDataset
8+
9+
10+
class RenderedSST2(VisionDataset):
11+
"""`The Rendered SST2 Dataset <https://github.com/openai/CLIP/blob/main/data/rendered-sst2.md>`_.
12+
13+
Rendered SST2 is an image classification dataset used to evaluate the models capability on optical
14+
character recognition. This dataset was generated by rendering sentences in the Standford Sentiment
15+
Treebank v2 dataset.
16+
17+
This dataset contains two classes (positive and negative) and is divided in three splits: a train
18+
split containing 6920 images (3610 positive and 3310 negative), a validation split containing 872 images
19+
(444 positive and 428 negative), and a test split containing 1821 images (909 positive and 912 negative).
20+
21+
Args:
22+
root (string): Root directory of the dataset.
23+
split (string, optional): The dataset split, supports ``"train"`` (default), `"val"` and ``"test"``.
24+
download (bool, optional): If True, downloads the dataset from the internet and
25+
puts it in root directory. If dataset is already downloaded, it is not
26+
downloaded again. Default is False.
27+
transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
28+
version. E.g, ``transforms.RandomCrop``.
29+
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
30+
"""
31+
32+
_URL = "https://openaipublic.azureedge.net/clip/data/rendered-sst2.tgz"
33+
_MD5 = "2384d08e9dcfa4bd55b324e610496ee5"
34+
35+
def __init__(
36+
self,
37+
root: str,
38+
split: str = "train",
39+
download: bool = False,
40+
transform: Optional[Callable] = None,
41+
target_transform: Optional[Callable] = None,
42+
) -> None:
43+
super().__init__(root, transform=transform, target_transform=target_transform)
44+
self._split = verify_str_arg(split, "split", ("train", "val", "test"))
45+
self._split_to_folder = {"train": "train", "val": "valid", "test": "test"}
46+
self._base_folder = Path(self.root) / "rendered-sst2"
47+
self.classes = ["negative", "positive"]
48+
self.class_to_idx = {"negative": 0, "positive": 1}
49+
50+
if download:
51+
self._download()
52+
53+
if not self._check_exists():
54+
raise RuntimeError("Dataset not found. You can use download=True to download it")
55+
56+
self._labels = []
57+
self._image_files = []
58+
59+
for p in (self._base_folder / self._split_to_folder[self._split]).glob("**/*.png"):
60+
self._labels.append(self.class_to_idx[p.parent.name])
61+
self._image_files.append(p)
62+
63+
def __len__(self) -> int:
64+
return len(self._image_files)
65+
66+
def __getitem__(self, idx) -> Tuple[Any, Any]:
67+
image_file, label = self._image_files[idx], self._labels[idx]
68+
image = PIL.Image.open(image_file).convert("RGB")
69+
70+
if self.transform:
71+
image = self.transform(image)
72+
73+
if self.target_transform:
74+
label = self.target_transform(label)
75+
76+
return image, label
77+
78+
def extra_repr(self) -> str:
79+
return f"split={self._split}"
80+
81+
def _check_exists(self) -> bool:
82+
for class_label in set(self.classes):
83+
if not (self._base_folder / self._split_to_folder[self._split] / class_label).is_dir():
84+
return False
85+
return True
86+
87+
def _download(self) -> None:
88+
if self._check_exists():
89+
return
90+
download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)

0 commit comments

Comments
 (0)