@@ -19,6 +19,9 @@ class SUN397(VisionDataset):
19
19
split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``.
20
20
parition (integer, optional): A valid partition can be an integer from 1 to 10 or None,
21
21
for the entire dataset.
22
+ download (bool, optional): If true, downloads the dataset from the internet and
23
+ puts it in root directory. If dataset is already downloaded, it is not
24
+ downloaded again.
22
25
transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
23
26
version. E.g, ``transforms.RandomCrop``.
24
27
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
@@ -41,30 +44,30 @@ def __init__(
41
44
super ().__init__ (root , transform = transform , target_transform = target_transform )
42
45
self .split = verify_str_arg (split , "split" , ("train" , "test" ))
43
46
self .partition = partition
44
- self .data_dir = Path (self .root ) / "SUN397"
47
+ self ._data_dir = Path (self .root ) / "SUN397"
45
48
46
49
if self .partition is not None :
47
50
if self .partition < 0 or self .partition > 10 :
48
- raise RuntimeError ("Enter a valid integer partition from 1 to 10 or None, for entire dataset " )
51
+ raise RuntimeError (f"The partition parameter should be an int in [1, 10] or None, got { partition } . " )
49
52
50
53
if download :
51
54
self ._download ()
52
55
53
56
if not self ._check_exists ():
54
57
raise RuntimeError ("Dataset not found. You can use download=True to download it" )
55
58
56
- with open (self .data_dir / "ClassName.txt" ) as f :
59
+ with open (self ._data_dir / "ClassName.txt" ) as f :
57
60
self .classes = [c [3 :].strip () for c in f ]
58
61
59
62
self .class_to_idx = dict (zip (self .classes , range (len (self .classes ))))
60
63
if self .partition is not None :
61
- with open (self .data_dir / f"{ self .split .title ()} ing_{ self .partition :02d} .txt" , "r" ) as f :
62
- self ._image_files = [self .data_dir .joinpath (* line .strip ()[1 :].split ("/" )) for line in f ]
64
+ with open (self ._data_dir / f"{ self .split .title ()} ing_{ self .partition :02d} .txt" , "r" ) as f :
65
+ self ._image_files = [self ._data_dir .joinpath (* line .strip ()[1 :].split ("/" )) for line in f ]
63
66
else :
64
- self ._image_files = list (self .data_dir .rglob ("sun_*.jpg" ))
67
+ self ._image_files = list (self ._data_dir .rglob ("sun_*.jpg" ))
65
68
66
69
self ._labels = [
67
- self .class_to_idx ["/" .join (path .relative_to (self .data_dir ).parts [1 :- 1 ])] for path in self ._image_files
70
+ self .class_to_idx ["/" .join (path .relative_to (self ._data_dir ).parts [1 :- 1 ])] for path in self ._image_files
68
71
]
69
72
70
73
def __len__ (self ) -> int :
@@ -83,13 +86,13 @@ def __getitem__(self, idx) -> Tuple[Any, Any]:
83
86
return image , label
84
87
85
88
def _check_exists (self ) -> bool :
86
- return self .data_dir .exists () and self .data_dir .is_dir ()
89
+ return self ._data_dir .exists () and self ._data_dir .is_dir ()
87
90
88
91
def extra_repr (self ) -> str :
89
92
return "Split: {split}" .format (** self .__dict__ )
90
93
91
94
def _download (self ) -> None :
92
- if self ._check_exists :
95
+ if self ._check_exists () :
93
96
return
94
- download_and_extract_archive (self ._DATASET_URL , download_root = self .root , md5 = self ._MD5 )
95
- download_and_extract_archive (self ._PARTITIONS_URL , download_root = str (self .data_dir ), md5 = self ._PARTITIONS_MD5 )
97
+ download_and_extract_archive (self ._DATASET_URL , download_root = self .root , md5 = self ._DATASET_MD5 )
98
+ download_and_extract_archive (self ._PARTITIONS_URL , download_root = str (self ._data_dir ), md5 = self ._PARTITIONS_MD5 )
0 commit comments