Skip to content

Commit 65676b4

Browse files
jdsgomes and pmeier authored
Food 101 dataset (#5119)
* Adding multiweight support for shufflenetv2 prototype models * Revert "Adding multiweight support for shufflenetv2 prototype models" This reverts commit 31fadbe. * Adding multiweight support for shufflenetv2 prototype models * Revert "Adding multiweight support for shufflenetv2 prototype models" This reverts commit 4e3d900. * Add Food101 Dataset Addresses #5108. cc @pmeier @NicolasHug * Remove unnecessary Path constructor calls * Remove unnecessary Path constructor calls and fix types * Fix tests * Address PR comments from @pmeier * Fix bug in tests and in food101 dataset * Fix bug in tests and in food101 dataset * Update torchvision/datasets/food101.py Co-authored-by: Philip Meier <[email protected]>
1 parent 8096c1b commit 65676b4

File tree

4 files changed

+130
-0
lines changed

4 files changed

+130
-0
lines changed

docs/source/datasets.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
4545
Flickr30k
4646
FlyingChairs
4747
FlyingThings3D
48+
Food101
4849
HD1K
4950
HMDB51
5051
ImageNet

test/test_datasets.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2168,5 +2168,42 @@ def inject_fake_data(self, tmpdir, config):
21682168
return num_sequences * (num_examples_per_sequence - 1)
21692169

21702170

2171+
class Food101TestCase(datasets_utils.ImageDatasetTestCase):
    """Test case for ``datasets.Food101`` backed by a small fake ``food-101`` tree."""

    DATASET_CLASS = datasets.Food101
    FEATURE_TYPES = (PIL.Image.Image, int)

    # Run the test suite once per supported split.
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))

    def inject_fake_data(self, tmpdir: str, config):
        """Create fake images and a split metadata file below ``tmpdir``.

        Layout written::

            food-101/images/<class>/<idx>.jpg
            food-101/meta/<split>.json

        Args:
            tmpdir: Directory that is used as the dataset root.
            config: Test configuration; only ``config["split"]`` is read here.

        Returns:
            Number of samples listed in the metadata file for the split.
        """
        root_folder = pathlib.Path(tmpdir) / "food-101"
        image_folder = root_folder / "images"
        meta_folder = root_folder / "meta"

        image_folder.mkdir(parents=True)
        meta_folder.mkdir()

        num_images_per_class = 5

        metadata = {}
        n_samples_per_class = 3 if config["split"] == "train" else 2
        sampled_classes = ("apple_pie", "crab_cakes", "gyoza")
        for cls in sampled_classes:
            im_fnames = datasets_utils.create_image_folder(
                image_folder,
                cls,
                file_name_fn=lambda idx: f"{idx}.jpg",
                num_examples=num_images_per_class,
            )
            # The metadata stores "<class>/<stem>" entries: paths relative to the
            # image folder, without the ".jpg" suffix and with "/" as separator.
            metadata[cls] = [
                "/".join(fname.relative_to(image_folder).with_suffix("").parts)
                for fname in random.choices(im_fnames, k=n_samples_per_class)
            ]

        with open(meta_folder / f"{config['split']}.json", "w") as file:
            json.dump(metadata, file)

        # Every class contributes the same number of metadata entries.
        return len(sampled_classes) * n_samples_per_class
2206+
2207+
21712208
if __name__ == "__main__":
21722209
unittest.main()

torchvision/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .fakedata import FakeData
88
from .flickr import Flickr8k, Flickr30k
99
from .folder import ImageFolder, DatasetFolder
10+
from .food101 import Food101
1011
from .hmdb51 import HMDB51
1112
from .imagenet import ImageNet
1213
from .inaturalist import INaturalist
@@ -77,4 +78,5 @@
7778
"FlyingChairs",
7879
"FlyingThings3D",
7980
"HD1K",
81+
"Food101",
8082
)

torchvision/datasets/food101.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import json
2+
from pathlib import Path
3+
from typing import Any, Tuple, Callable, Optional
4+
5+
import PIL.Image
6+
7+
from .utils import verify_str_arg, download_and_extract_archive
8+
from .vision import VisionDataset
9+
10+
11+
class Food101(VisionDataset):
    """`The Food-101 Data Set <https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/>`_.

    The Food-101 is a challenging data set of 101 food categories, with 101'000 images.
    For each class, 250 manually reviewed test images are provided as well as 750 training images.
    On purpose, the training images were not cleaned, and thus still contain some amount of noise.
    This comes mostly in the form of intense colors and sometimes wrong labels. All images were
    rescaled to have a maximum side length of 512 pixels.


    Args:
        root (string): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``.
        download (bool, optional): If ``True`` (default), the archive is downloaded and extracted into
            ``root`` unless the ``food-101/meta`` and ``food-101/images`` folders already exist there.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.

    Raises:
        RuntimeError: If the ``meta`` and ``images`` folders are still missing after the optional download.
    """

    _URL = "http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz"
    _MD5 = "85eeb15f3717b99a5da872d97d918f87"

    def __init__(
        self,
        root: str,
        split: str = "train",
        download: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self._split = verify_str_arg(split, "split", ("train", "test"))
        self._base_folder = Path(self.root) / "food-101"
        self._meta_folder = self._base_folder / "meta"
        self._images_folder = self._base_folder / "images"

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        # The split file maps each class name to a list of "<class>/<image id>"
        # entries that carry no file suffix.
        with open(self._meta_folder / f"{split}.json", "r") as f:
            metadata = json.load(f)

        # Sorting makes the class -> index assignment independent of the key
        # order inside the JSON file.
        self.classes = sorted(metadata.keys())
        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))

        self._labels = []
        self._image_files = []
        for class_label, im_rel_paths in metadata.items():
            self._labels += [self.class_to_idx[class_label]] * len(im_rel_paths)
            # Split on "/" so the path is assembled with the platform's separator.
            self._image_files += [
                self._images_folder.joinpath(*f"{im_rel_path}.jpg".split("/")) for im_rel_path in im_rel_paths
            ]

    def __len__(self) -> int:
        """Return the number of images listed for the selected split."""
        return len(self._image_files)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Load one sample.

        Args:
            idx: Position of the sample in the split.

        Returns:
            tuple: ``(image, label)`` where ``image`` is the RGB image (after ``transform``, if given)
            and ``label`` is the class index (after ``target_transform``, if given).
        """
        image_file, label = self._image_files[idx], self._labels[idx]
        image = PIL.Image.open(image_file).convert("RGB")

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)

        return image, label

    def extra_repr(self) -> str:
        """Return the split description as ``split=<split>``."""
        return f"split={self._split}"

    def _check_exists(self) -> bool:
        """Return ``True`` if both the ``meta`` and the ``images`` folder are present."""
        # ``is_dir`` is already False for a missing path, so no separate ``exists`` check is needed.
        return all(folder.is_dir() for folder in (self._meta_folder, self._images_folder))

    def _download(self) -> None:
        """Download and extract the archive into ``root`` unless the dataset folders already exist."""
        if self._check_exists():
            return
        download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)

0 commit comments

Comments
 (0)