|
| 1 | +import json |
| 2 | +from pathlib import Path |
| 3 | +from typing import Any, Tuple, Callable, Optional |
| 4 | + |
| 5 | +import PIL.Image |
| 6 | + |
| 7 | +from .utils import verify_str_arg, download_and_extract_archive |
| 8 | +from .vision import VisionDataset |
| 9 | + |
| 10 | + |
| 11 | +class Food101(VisionDataset): |
| 12 | + """`The Food-101 Data Set <https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/>`_. |
| 13 | +
|
| 14 | + The Food-101 is a challenging data set of 101 food categories, with 101'000 images. |
| 15 | + For each class, 250 manually reviewed test images are provided as well as 750 training images. |
| 16 | + On purpose, the training images were not cleaned, and thus still contain some amount of noise. |
| 17 | + This comes mostly in the form of intense colors and sometimes wrong labels. All images were |
| 18 | + rescaled to have a maximum side length of 512 pixels. |
| 19 | +
|
| 20 | +
|
| 21 | + Args: |
| 22 | + root (string): Root directory of the dataset. |
| 23 | + split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``. |
| 24 | + transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed |
| 25 | + version. E.g, ``transforms.RandomCrop``. |
| 26 | + target_transform (callable, optional): A function/transform that takes in the target and transforms it. |
| 27 | + """ |
| 28 | + |
| 29 | + _URL = "http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz" |
| 30 | + _MD5 = "85eeb15f3717b99a5da872d97d918f87" |
| 31 | + |
| 32 | + def __init__( |
| 33 | + self, |
| 34 | + root: str, |
| 35 | + split: str = "train", |
| 36 | + download: bool = True, |
| 37 | + transform: Optional[Callable] = None, |
| 38 | + target_transform: Optional[Callable] = None, |
| 39 | + ) -> None: |
| 40 | + super().__init__(root, transform=transform, target_transform=target_transform) |
| 41 | + self._split = verify_str_arg(split, "split", ("train", "test")) |
| 42 | + self._base_folder = Path(self.root) / "food-101" |
| 43 | + self._meta_folder = self._base_folder / "meta" |
| 44 | + self._images_folder = self._base_folder / "images" |
| 45 | + |
| 46 | + if download: |
| 47 | + self._download() |
| 48 | + |
| 49 | + if not self._check_exists(): |
| 50 | + raise RuntimeError("Dataset not found. You can use download=True to download it") |
| 51 | + |
| 52 | + self._labels = [] |
| 53 | + self._image_files = [] |
| 54 | + with open(self._meta_folder / f"{split}.json", "r") as f: |
| 55 | + metadata = json.loads(f.read()) |
| 56 | + |
| 57 | + self.classes = sorted(metadata.keys()) |
| 58 | + self.class_to_idx = dict(zip(self.classes, range(len(self.classes)))) |
| 59 | + |
| 60 | + for class_label, im_rel_paths in metadata.items(): |
| 61 | + self._labels += [self.class_to_idx[class_label]] * len(im_rel_paths) |
| 62 | + self._image_files += [ |
| 63 | + self._images_folder.joinpath(*f"{im_rel_path}.jpg".split("/")) for im_rel_path in im_rel_paths |
| 64 | + ] |
| 65 | + |
| 66 | + def __len__(self) -> int: |
| 67 | + return len(self._image_files) |
| 68 | + |
| 69 | + def __getitem__(self, idx) -> Tuple[Any, Any]: |
| 70 | + image_file, label = self._image_files[idx], self._labels[idx] |
| 71 | + image = PIL.Image.open(image_file).convert("RGB") |
| 72 | + |
| 73 | + if self.transform: |
| 74 | + image = self.transform(image) |
| 75 | + |
| 76 | + if self.target_transform: |
| 77 | + label = self.target_transform(label) |
| 78 | + |
| 79 | + return image, label |
| 80 | + |
| 81 | + def extra_repr(self) -> str: |
| 82 | + return f"split={self._split}" |
| 83 | + |
| 84 | + def _check_exists(self) -> bool: |
| 85 | + return all(folder.exists() and folder.is_dir() for folder in (self._meta_folder, self._images_folder)) |
| 86 | + |
| 87 | + def _download(self) -> None: |
| 88 | + if self._check_exists(): |
| 89 | + return |
| 90 | + download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5) |
0 commit comments