Skip to content

Commit a8f2ded

Browse files
zhiqwang, pmeier, and NicolasHug
authored
Add Flowers102 dataset (#5177)
* Add Flowers102 datasets * Fix initialization of images and labels * Fix _check_exists in Flowers102 * Add Flowers102 to datasets and docs * Add Flowers102TestCase to unittest * Fixing Python type statically * Shuffle the fake labels * Update test/test_datasets.py Co-authored-by: Philip Meier <[email protected]> * Apply the suggestions by pmeier * Use check_integrity to check file existence * Save the labels to base_folder * Minor fixes * Using a loop makes this more concise without reducing readability Co-authored-by: Philip Meier <[email protected]> * Using a loop makes this more concise without reducing readability Co-authored-by: Philip Meier <[email protected]> * Remove self.labels and self.label_to_index attributes * minor simplification * Check the existence of image folder * Revert the check * Check the existence of image folder * valid -> val * keep some stuff private * minor doc arrangements * remove default FEATURE_TYPES * Simplify the datasets existence * check if the image folder exists Co-authored-by: Philip Meier <[email protected]> * isdir -> is_dir Co-authored-by: Philip Meier <[email protected]> Co-authored-by: Nicolas Hug <[email protected]>
1 parent 1c63096 commit a8f2ded

File tree

4 files changed

+153
-0
lines changed

4 files changed

+153
-0
lines changed

docs/source/datasets.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
4646
FER2013
4747
Flickr8k
4848
Flickr30k
49+
Flowers102
4950
FlyingChairs
5051
FlyingThings3D
5152
Food101

test/test_datasets.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2490,5 +2490,41 @@ def inject_fake_data(self, tmpdir: str, config):
24902490
return num_examples * len(classes)
24912491

24922492

2493+
class Flowers102TestCase(datasets_utils.ImageDatasetTestCase):
    """Run the shared image-dataset checks against ``datasets.Flowers102``.

    Every split (``"train"``, ``"val"``, ``"test"``) is exercised, and the
    tests require ``scipy`` because the fake annotations are written as
    ``.mat`` files.
    """

    DATASET_CLASS = datasets.Flowers102

    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val", "test"))
    REQUIRED_PACKAGES = ("scipy",)

    def inject_fake_data(self, tmpdir: str, config):
        """Create a miniature ``flowers-102`` folder inside ``tmpdir``.

        Writes a ``jpg`` image folder plus ``imagelabels.mat`` and
        ``setid.mat`` and returns the number of images that belong to
        ``config["split"]``.
        """
        root = pathlib.Path(tmpdir) / "flowers-102"

        n_classes = 3
        split_sizes = dict(train=5, val=4, test=3)
        n_images = sum(split_sizes.values())

        def image_name(idx):
            # Image ids in the file names are 1-based and zero-padded to five digits.
            return f"image_{idx + 1:05d}.jpg"

        datasets_utils.create_image_folder(
            root,
            "jpg",
            file_name_fn=image_name,
            num_examples=n_images,
        )

        # One random label in [1, n_classes] per image, stored as a 1 x N row.
        fake_labels = np.random.randint(1, n_classes + 1, size=(1, n_images), dtype=np.uint8)
        savemat = datasets_utils.lazy_importer.scipy.io.savemat
        savemat(str(root / "imagelabels.mat"), dict(labels=fake_labels))

        # Shuffle all image ids, then carve them into train / val / test:
        # train takes the head, test takes the tail, val gets what is left.
        image_ids = np.arange(1, n_images + 1, dtype=np.uint16)
        np.random.shuffle(image_ids)
        train_stop = split_sizes["train"]
        test_start = -split_sizes["test"]
        split_ids = {
            "trnid": image_ids[:train_stop],
            "valid": image_ids[train_stop:test_start],
            "tstid": image_ids[test_start:],
        }
        savemat(
            str(root / "setid.mat"),
            {key: ids.reshape(1, -1) for key, ids in split_ids.items()},
        )

        return split_sizes[config["split"]]
2527+
2528+
24932529
if __name__ == "__main__":
24942530
unittest.main()

torchvision/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .fakedata import FakeData
1111
from .fer2013 import FER2013
1212
from .flickr import Flickr8k, Flickr30k
13+
from .flowers102 import Flowers102
1314
from .folder import ImageFolder, DatasetFolder
1415
from .food101 import Food101
1516
from .gtsrb import GTSRB
@@ -61,6 +62,7 @@
6162
"SBU",
6263
"Flickr8k",
6364
"Flickr30k",
65+
"Flowers102",
6466
"VOCSegmentation",
6567
"VOCDetection",
6668
"Cityscapes",

torchvision/datasets/flowers102.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
from pathlib import Path
2+
from typing import Any, Tuple, Callable, Optional
3+
4+
import PIL.Image
5+
6+
from .utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg
7+
from .vision import VisionDataset
8+
9+
10+
class Flowers102(VisionDataset):
    """`Oxford 102 Flower <https://www.robots.ox.ac.uk/~vgg/data/flowers/102/>`_ Dataset.

    .. warning::

        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.

    Oxford 102 Flower is an image classification dataset consisting of 102 flower categories. The
    flowers were chosen to be flowers commonly occurring in the United Kingdom. Each class consists of
    between 40 and 258 images.

    The images have large scale, pose and light variations. In addition, there are categories that
    have large variations within the category, and several very similar categories.

    Args:
        root (string): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
        download (bool, optional): If true (the default), downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a
            transformed version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.

    Raises:
        RuntimeError: If the image folder or the annotation files are missing or fail
            their MD5 check after the optional download step.
    """

    _download_url_prefix = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/"
    _file_dict = {  # filename, md5
        "image": ("102flowers.tgz", "52808999861908f626f3c1f4e79d11fa"),
        "label": ("imagelabels.mat", "e0620be6f572b9609742df49c70aed4d"),
        "setid": ("setid.mat", "a5357ecc9cb78c4bef273ce3793fc85c"),
    }
    # Public split name -> variable name inside ``setid.mat``.
    _splits_map = {"train": "trnid", "val": "valid", "test": "tstid"}

    def __init__(
        self,
        root: str,
        split: str = "train",
        download: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
        self._base_folder = Path(self.root) / "flowers-102"
        self._images_folder = self._base_folder / "jpg"

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        # scipy is imported lazily so that it is only required when this dataset is used.
        from scipy.io import loadmat

        # Load without ``squeeze_me`` and flatten explicitly: squeezing collapses a
        # single-element array to a scalar, whose ``tolist()`` is not iterable.
        set_ids = loadmat(self._base_folder / self._file_dict["setid"][0])
        image_ids = set_ids[self._splits_map[self._split]].ravel().tolist()

        labels = loadmat(self._base_folder / self._file_dict["label"][0])
        # Image ids are 1-based, so the i-th stored label belongs to image id i + 1.
        image_id_to_label = dict(enumerate(labels["labels"].ravel().tolist(), 1))

        self._labels = [image_id_to_label[image_id] for image_id in image_ids]
        self._image_files = [self._images_folder / f"image_{image_id:05d}.jpg" for image_id in image_ids]

    def __len__(self) -> int:
        """Return the number of images in the selected split."""
        return len(self._image_files)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Return the ``(image, label)`` pair at position ``idx``.

        The image is opened with PIL and converted to RGB; ``transform`` and
        ``target_transform`` are applied when they were provided.
        """
        image_file, label = self._image_files[idx], self._labels[idx]
        image = PIL.Image.open(image_file).convert("RGB")

        # Compare against None rather than relying on truthiness, so that a
        # callable which happens to be falsy is still applied.
        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label

    def extra_repr(self) -> str:
        """Return the split description shown in the dataset's repr."""
        return f"split={self._split}"

    def _check_integrity(self) -> bool:
        """Return True if the image folder exists and both ``.mat`` files pass their MD5 check."""
        # ``is_dir`` is already False for a path that does not exist.
        if not self._images_folder.is_dir():
            return False

        for key in ["label", "setid"]:
            filename, md5 = self._file_dict[key]
            if not check_integrity(str(self._base_folder / filename), md5):
                return False
        return True

    def download(self) -> None:
        """Download and extract the image archive and fetch the annotation files.

        Does nothing when the dataset already passes the integrity check.
        """
        if self._check_integrity():
            return
        download_and_extract_archive(
            f"{self._download_url_prefix}{self._file_dict['image'][0]}",
            str(self._base_folder),
            md5=self._file_dict["image"][1],
        )
        for key in ["label", "setid"]:
            filename, md5 = self._file_dict[key]
            download_url(self._download_url_prefix + filename, str(self._base_folder), md5=md5)

0 commit comments

Comments
 (0)