Skip to content

Commit 6c44f6c

Browse files
prabhat00155NicolasHugpmeier
authored andcommitted
[fbsync] Add SUN397 Dataset (#5132)
Summary: * dataset class added * fix code format * fixed requested changes * fixed issues in sun397 * Update torchvision/datasets/sun397.py Reviewed By: sallysyw Differential Revision: D33479277 fbshipit-source-id: 374d098c261adeacd073fae141380130a6c3aa95 Co-authored-by: Nicolas Hug <[email protected]> Co-authored-by: Philip Meier <[email protected]>
1 parent 4027ebc commit 6c44f6c

File tree

4 files changed

+147
-0
lines changed

4 files changed

+147
-0
lines changed

docs/source/datasets.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
7171
SEMEION
7272
Sintel
7373
STL10
74+
SUN397
7475
SVHN
7576
UCF101
7677
USPS

test/test_datasets.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2206,6 +2206,52 @@ def inject_fake_data(self, tmpdir: str, config):
22062206
return len(sampled_classes * n_samples_per_class)
22072207

22082208

2209+
class SUN397TestCase(datasets_utils.ImageDatasetTestCase):
2210+
DATASET_CLASS = datasets.SUN397
2211+
2212+
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
2213+
split=("train", "test"),
2214+
partition=(1, 10, None),
2215+
)
2216+
2217+
def inject_fake_data(self, tmpdir: str, config):
2218+
data_dir = pathlib.Path(tmpdir) / "SUN397"
2219+
data_dir.mkdir()
2220+
2221+
num_images_per_class = 5
2222+
sampled_classes = ("abbey", "airplane_cabin", "airport_terminal")
2223+
im_paths = []
2224+
2225+
for cls in sampled_classes:
2226+
image_folder = data_dir / cls[0]
2227+
im_paths.extend(
2228+
datasets_utils.create_image_folder(
2229+
image_folder,
2230+
image_folder / cls,
2231+
file_name_fn=lambda idx: f"sun_{idx}.jpg",
2232+
num_examples=num_images_per_class,
2233+
)
2234+
)
2235+
2236+
with open(data_dir / "ClassName.txt", "w") as file:
2237+
file.writelines("\n".join(f"/{cls[0]}/{cls}" for cls in sampled_classes))
2238+
2239+
if config["partition"] is not None:
2240+
num_samples = max(len(im_paths) // (2 if config["split"] == "train" else 3), 1)
2241+
2242+
with open(data_dir / f"{config['split'].title()}ing_{config['partition']:02d}.txt", "w") as file:
2243+
file.writelines(
2244+
"\n".join(
2245+
f"/{f_path.relative_to(data_dir).as_posix()}"
2246+
for f_path in random.choices(im_paths, k=num_samples)
2247+
)
2248+
)
2249+
else:
2250+
num_samples = len(im_paths)
2251+
2252+
return num_samples
2253+
2254+
22092255
class DTDTestCase(datasets_utils.ImageDatasetTestCase):
22102256
DATASET_CLASS = datasets.DTD
22112257
FEATURE_TYPES = (PIL.Image.Image, int)

torchvision/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from .sbu import SBU
2929
from .semeion import SEMEION
3030
from .stl10 import STL10
31+
from .sun397 import SUN397
3132
from .svhn import SVHN
3233
from .ucf101 import UCF101
3334
from .usps import USPS
@@ -51,6 +52,7 @@
5152
"MNIST",
5253
"KMNIST",
5354
"STL10",
55+
"SUN397",
5456
"SVHN",
5557
"PhotoTour",
5658
"SEMEION",

torchvision/datasets/sun397.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
from pathlib import Path
2+
from typing import Any, Tuple, Callable, Optional
3+
4+
import PIL.Image
5+
6+
from .utils import verify_str_arg, download_and_extract_archive
7+
from .vision import VisionDataset
8+
9+
10+
class SUN397(VisionDataset):
11+
"""`The SUN397 Data Set <https://vision.princeton.edu/projects/2010/SUN/>`_.
12+
13+
The SUN397 or Scene UNderstanding (SUN) is a dataset for scene recognition consisting of
14+
397 categories with 108'754 images. The dataset also provides 10 partitions for training
15+
and testing, with each partition consisting of 50 images per class.
16+
17+
Args:
18+
root (string): Root directory of the dataset.
19+
split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``.
20+
partition (int, optional): A valid partition can be an integer from 1 to 10 or None,
21+
for the entire dataset.
22+
download (bool, optional): If true, downloads the dataset from the internet and
23+
puts it in root directory. If dataset is already downloaded, it is not
24+
downloaded again.
25+
transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
26+
version. E.g, ``transforms.RandomCrop``.
27+
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
28+
"""
29+
30+
_DATASET_URL = "http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz"
31+
_DATASET_MD5 = "8ca2778205c41d23104230ba66911c7a"
32+
_PARTITIONS_URL = "https://vision.princeton.edu/projects/2010/SUN/download/Partitions.zip"
33+
_PARTITIONS_MD5 = "29a205c0a0129d21f36cbecfefe81881"
34+
35+
def __init__(
36+
self,
37+
root: str,
38+
split: str = "train",
39+
partition: Optional[int] = 1,
40+
download: bool = True,
41+
transform: Optional[Callable] = None,
42+
target_transform: Optional[Callable] = None,
43+
) -> None:
44+
super().__init__(root, transform=transform, target_transform=target_transform)
45+
self.split = verify_str_arg(split, "split", ("train", "test"))
46+
self.partition = partition
47+
self._data_dir = Path(self.root) / "SUN397"
48+
49+
if self.partition is not None:
50+
if self.partition < 0 or self.partition > 10:
51+
raise RuntimeError(f"The partition parameter should be an int in [1, 10] or None, got {partition}.")
52+
53+
if download:
54+
self._download()
55+
56+
if not self._check_exists():
57+
raise RuntimeError("Dataset not found. You can use download=True to download it")
58+
59+
with open(self._data_dir / "ClassName.txt") as f:
60+
self.classes = [c[3:].strip() for c in f]
61+
62+
self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
63+
if self.partition is not None:
64+
with open(self._data_dir / f"{self.split.title()}ing_{self.partition:02d}.txt", "r") as f:
65+
self._image_files = [self._data_dir.joinpath(*line.strip()[1:].split("/")) for line in f]
66+
else:
67+
self._image_files = list(self._data_dir.rglob("sun_*.jpg"))
68+
69+
self._labels = [
70+
self.class_to_idx["/".join(path.relative_to(self._data_dir).parts[1:-1])] for path in self._image_files
71+
]
72+
73+
def __len__(self) -> int:
74+
return len(self._image_files)
75+
76+
def __getitem__(self, idx) -> Tuple[Any, Any]:
77+
image_file, label = self._image_files[idx], self._labels[idx]
78+
image = PIL.Image.open(image_file).convert("RGB")
79+
80+
if self.transform:
81+
image = self.transform(image)
82+
83+
if self.target_transform:
84+
label = self.target_transform(label)
85+
86+
return image, label
87+
88+
def _check_exists(self) -> bool:
89+
return self._data_dir.exists() and self._data_dir.is_dir()
90+
91+
def extra_repr(self) -> str:
92+
return "Split: {split}".format(**self.__dict__)
93+
94+
def _download(self) -> None:
95+
if self._check_exists():
96+
return
97+
download_and_extract_archive(self._DATASET_URL, download_root=self.root, md5=self._DATASET_MD5)
98+
download_and_extract_archive(self._PARTITIONS_URL, download_root=str(self._data_dir), md5=self._PARTITIONS_MD5)

0 commit comments

Comments
 (0)