Skip to content

Commit 1407dbd

Browse files
committed
Address PR comments from @pmeier
1 parent 72a8eaa commit 1407dbd

File tree

2 files changed

+10
-27
lines changed

2 files changed

+10
-27
lines changed

test/test_datasets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2180,7 +2180,7 @@ def inject_fake_data(self, tmpdir: str, config):
21802180
meta_folder = root_folder / "meta"
21812181

21822182
image_folder.mkdir(parents=True)
2183-
meta_folder.mkdir(parents=True)
2183+
meta_folder.mkdir()
21842184

21852185
num_images_per_class = 5
21862186

torchvision/datasets/food101.py

Lines changed: 9 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ class Food101(VisionDataset):
2121
Args:
2222
root (string): Root directory of the dataset.
2323
split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``.
24-
2524
transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
2625
version. E.g, ``transforms.RandomCrop``.
2726
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
@@ -41,6 +40,9 @@ def __init__(
4140
super().__init__(root, transform=transform, target_transform=target_transform)
4241
self._split = verify_str_arg(split, "split", ("train", "test"))
4342
self._root_path = Path(self.root)
43+
self._base_folder = self._root_path / "food-101"
44+
self._meta_folder = self._base_folder / "meta"
45+
self._images_folder = self._base_folder / "images"
4446

4547
if download:
4648
self._download()
@@ -53,12 +55,12 @@ def __init__(
5355
with open(self._meta_folder / f"{split}.json", "r") as f:
5456
metadata = json.loads(f.read())
5557

56-
self.classes = sorted(set(metadata.keys()))
58+
self.classes = sorted(metadata.keys())
5759
self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
5860

59-
for class_label, im_paths in metadata.items():
60-
self._labels += [self.class_to_idx[class_label]] * len(im_paths)
61-
self._image_files += im_paths
61+
for class_label, im_ids in metadata.items():
62+
self._labels += [self.class_to_idx[class_label]] * len(im_ids)
63+
self._image_files += [self._images_folder.joinpath(*f"{im_id}.jpg".split("/")) for im_id in im_ids]
6264

6365
def __len__(self) -> int:
6466
return len(self._image_files)
@@ -78,29 +80,10 @@ def __getitem__(self, idx) -> Tuple[Any, Any]:
7880
def extra_repr(self) -> str:
7981
return f"split={self._split}"
8082

81-
@property
82-
def _base_folder(self) -> Path:
83-
return self._root_path / "food-101"
84-
85-
@property
86-
def _meta_folder(self) -> Path:
87-
return self._base_folder / "meta"
88-
89-
@property
90-
def _images_folder(self) -> Path:
91-
return self._base_folder / "images"
92-
9383
def _check_exists(self) -> bool:
94-
return (
95-
self._base_folder.exists()
96-
and self._base_folder.is_dir()
97-
and self._meta_folder.exists()
98-
and self._meta_folder.is_dir()
99-
and self._images_folder.exists()
100-
and self._images_folder.is_dir()
101-
)
84+
return all(folder.exists() and folder.is_dir() for folder in (self._meta_folder, self._images_folder))
10285

10386
def _download(self) -> None:
10487
if self._check_exists():
10588
return
106-
download_and_extract_archive(self._URL, download_root=str(self._base_folder), md5=self._MD5)
89+
download_and_extract_archive(self._URL, download_root=str(self.root), md5=self._MD5)

0 commit comments

Comments
 (0)