|
| 1 | +from pathlib import Path |
| 2 | +from typing import Any, Tuple, Callable, Optional |
| 3 | + |
| 4 | +import PIL.Image |
| 5 | + |
| 6 | +from .utils import verify_str_arg, download_and_extract_archive |
| 7 | +from .vision import VisionDataset |
| 8 | + |
| 9 | + |
| 10 | +class SUN397(VisionDataset): |
| 11 | + """`The SUN397 Data Set <https://vision.princeton.edu/projects/2010/SUN/>`_. |
| 12 | +
|
| 13 | + The SUN397 or Scene UNderstanding (SUN) is a dataset for scene recognition consisting of |
| 14 | + 397 categories with 108'754 images. The dataset also provides 10 partitions for training |
| 15 | + and testing, with each partition consisting of 50 images per class. |
| 16 | +
|
| 17 | + Args: |
| 18 | + root (string): Root directory of the dataset. |
| 19 | + split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``. |
| 20 | + partition (int, optional): A valid partition can be an integer from 1 to 10 or None, |
| 21 | + for the entire dataset. |
| 22 | + download (bool, optional): If true, downloads the dataset from the internet and |
| 23 | + puts it in root directory. If dataset is already downloaded, it is not |
| 24 | + downloaded again. |
| 25 | + transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed |
| 26 | + version. E.g, ``transforms.RandomCrop``. |
| 27 | + target_transform (callable, optional): A function/transform that takes in the target and transforms it. |
| 28 | + """ |
| 29 | + |
| 30 | + _DATASET_URL = "http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz" |
| 31 | + _DATASET_MD5 = "8ca2778205c41d23104230ba66911c7a" |
| 32 | + _PARTITIONS_URL = "https://vision.princeton.edu/projects/2010/SUN/download/Partitions.zip" |
| 33 | + _PARTITIONS_MD5 = "29a205c0a0129d21f36cbecfefe81881" |
| 34 | + |
| 35 | + def __init__( |
| 36 | + self, |
| 37 | + root: str, |
| 38 | + split: str = "train", |
| 39 | + partition: Optional[int] = 1, |
| 40 | + download: bool = True, |
| 41 | + transform: Optional[Callable] = None, |
| 42 | + target_transform: Optional[Callable] = None, |
| 43 | + ) -> None: |
| 44 | + super().__init__(root, transform=transform, target_transform=target_transform) |
| 45 | + self.split = verify_str_arg(split, "split", ("train", "test")) |
| 46 | + self.partition = partition |
| 47 | + self._data_dir = Path(self.root) / "SUN397" |
| 48 | + |
| 49 | + if self.partition is not None: |
| 50 | + if self.partition < 0 or self.partition > 10: |
| 51 | + raise RuntimeError(f"The partition parameter should be an int in [1, 10] or None, got {partition}.") |
| 52 | + |
| 53 | + if download: |
| 54 | + self._download() |
| 55 | + |
| 56 | + if not self._check_exists(): |
| 57 | + raise RuntimeError("Dataset not found. You can use download=True to download it") |
| 58 | + |
| 59 | + with open(self._data_dir / "ClassName.txt") as f: |
| 60 | + self.classes = [c[3:].strip() for c in f] |
| 61 | + |
| 62 | + self.class_to_idx = dict(zip(self.classes, range(len(self.classes)))) |
| 63 | + if self.partition is not None: |
| 64 | + with open(self._data_dir / f"{self.split.title()}ing_{self.partition:02d}.txt", "r") as f: |
| 65 | + self._image_files = [self._data_dir.joinpath(*line.strip()[1:].split("/")) for line in f] |
| 66 | + else: |
| 67 | + self._image_files = list(self._data_dir.rglob("sun_*.jpg")) |
| 68 | + |
| 69 | + self._labels = [ |
| 70 | + self.class_to_idx["/".join(path.relative_to(self._data_dir).parts[1:-1])] for path in self._image_files |
| 71 | + ] |
| 72 | + |
| 73 | + def __len__(self) -> int: |
| 74 | + return len(self._image_files) |
| 75 | + |
| 76 | + def __getitem__(self, idx) -> Tuple[Any, Any]: |
| 77 | + image_file, label = self._image_files[idx], self._labels[idx] |
| 78 | + image = PIL.Image.open(image_file).convert("RGB") |
| 79 | + |
| 80 | + if self.transform: |
| 81 | + image = self.transform(image) |
| 82 | + |
| 83 | + if self.target_transform: |
| 84 | + label = self.target_transform(label) |
| 85 | + |
| 86 | + return image, label |
| 87 | + |
| 88 | + def _check_exists(self) -> bool: |
| 89 | + return self._data_dir.exists() and self._data_dir.is_dir() |
| 90 | + |
| 91 | + def extra_repr(self) -> str: |
| 92 | + return "Split: {split}".format(**self.__dict__) |
| 93 | + |
| 94 | + def _download(self) -> None: |
| 95 | + if self._check_exists(): |
| 96 | + return |
| 97 | + download_and_extract_archive(self._DATASET_URL, download_root=self.root, md5=self._DATASET_MD5) |
| 98 | + download_and_extract_archive(self._PARTITIONS_URL, download_root=str(self._data_dir), md5=self._PARTITIONS_MD5) |
0 commit comments