Skip to content

Commit 61b4f05

Browse files
committed
fixed issues in sun397
1 parent 8704eab commit 61b4f05

File tree

1 file changed

+14
-11
lines changed

1 file changed

+14
-11
lines changed

torchvision/datasets/sun397.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ class SUN397(VisionDataset):
1919
split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``.
2020
parition (integer, optional): A valid partition can be an integer from 1 to 10 or None,
2121
for the entire dataset.
22+
download (bool, optional): If true, downloads the dataset from the internet and
23+
puts it in root directory. If dataset is already downloaded, it is not
24+
downloaded again.
2225
transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
2326
version. E.g, ``transforms.RandomCrop``.
2427
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
@@ -41,30 +44,30 @@ def __init__(
4144
super().__init__(root, transform=transform, target_transform=target_transform)
4245
self.split = verify_str_arg(split, "split", ("train", "test"))
4346
self.partition = partition
44-
self.data_dir = Path(self.root) / "SUN397"
47+
self._data_dir = Path(self.root) / "SUN397"
4548

4649
if self.partition is not None:
4750
if self.partition < 0 or self.partition > 10:
48-
raise RuntimeError("Enter a valid integer partition from 1 to 10 or None, for entire dataset")
51+
raise RuntimeError(f"The partition parameter should be an int in [1, 10] or None, got {partition}.")
4952

5053
if download:
5154
self._download()
5255

5356
if not self._check_exists():
5457
raise RuntimeError("Dataset not found. You can use download=True to download it")
5558

56-
with open(self.data_dir / "ClassName.txt") as f:
59+
with open(self._data_dir / "ClassName.txt") as f:
5760
self.classes = [c[3:].strip() for c in f]
5861

5962
self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
6063
if self.partition is not None:
61-
with open(self.data_dir / f"{self.split.title()}ing_{self.partition:02d}.txt", "r") as f:
62-
self._image_files = [self.data_dir.joinpath(*line.strip()[1:].split("/")) for line in f]
64+
with open(self._data_dir / f"{self.split.title()}ing_{self.partition:02d}.txt", "r") as f:
65+
self._image_files = [self._data_dir.joinpath(*line.strip()[1:].split("/")) for line in f]
6366
else:
64-
self._image_files = list(self.data_dir.rglob("sun_*.jpg"))
67+
self._image_files = list(self._data_dir.rglob("sun_*.jpg"))
6568

6669
self._labels = [
67-
self.class_to_idx["/".join(path.relative_to(self.data_dir).parts[1:-1])] for path in self._image_files
70+
self.class_to_idx["/".join(path.relative_to(self._data_dir).parts[1:-1])] for path in self._image_files
6871
]
6972

7073
def __len__(self) -> int:
@@ -83,13 +86,13 @@ def __getitem__(self, idx) -> Tuple[Any, Any]:
8386
return image, label
8487

8588
def _check_exists(self) -> bool:
86-
return self.data_dir.exists() and self.data_dir.is_dir()
89+
return self._data_dir.exists() and self._data_dir.is_dir()
8790

8891
def extra_repr(self) -> str:
8992
return "Split: {split}".format(**self.__dict__)
9093

9194
def _download(self) -> None:
92-
if self._check_exists:
95+
if self._check_exists():
9396
return
94-
download_and_extract_archive(self._DATASET_URL, download_root=self.root, md5=self._MD5)
95-
download_and_extract_archive(self._PARTITIONS_URL, download_root=str(self.data_dir), md5=self._PARTITIONS_MD5)
97+
download_and_extract_archive(self._DATASET_URL, download_root=self.root, md5=self._DATASET_MD5)
98+
download_and_extract_archive(self._PARTITIONS_URL, download_root=str(self._data_dir), md5=self._PARTITIONS_MD5)

0 commit comments

Comments
 (0)