Skip to content

Adding fvgc_aircraft dataset #5178

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jan 14, 2022
1 change: 1 addition & 0 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
FlyingChairs
FlyingThings3D
Food101
FGVCAircraft
GTSRB
HD1K
HMDB51
Expand Down
51 changes: 51 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2206,6 +2206,57 @@ def inject_fake_data(self, tmpdir: str, config):
return len(sampled_classes * n_samples_per_class)


class FGVCAircraftTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.FGVCAircraft
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
split=("train", "val", "trainval", "test"), annotation_level=("variant", "family", "manufacturer")
)

def inject_fake_data(self, tmpdir: str, config):
split = config["split"]
annotation_level = config["annotation_level"]
annotation_level_to_file = {
"variant": "variants.txt",
"family": "families.txt",
"manufacturer": "manufacturers.txt",
}

root_folder = pathlib.Path(tmpdir) / "fgvc-aircraft-2013b"
data_folder = root_folder / "data"

classes = ["707-320", "Hawk T1", "Tornado"]
num_images_per_class = 5

datasets_utils.create_image_folder(
data_folder,
"images",
file_name_fn=lambda idx: f"{idx}.jpg",
num_examples=num_images_per_class * len(classes),
)

annotation_file = data_folder / annotation_level_to_file[annotation_level]
with open(annotation_file, "w") as file:
file.write("\n".join(classes))

num_samples_per_class = 4 if split == "trainval" else 2
images_classes = []
for i in range(len(classes)):
images_classes.extend(
[
f"{idx} {classes[i]}"
for idx in random.sample(
range(i * num_images_per_class, (i + 1) * num_images_per_class), num_samples_per_class
)
]
)

images_annotation_file = data_folder / f"images_{annotation_level}_{split}.txt"
with open(images_annotation_file, "w") as file:
file.write("\n".join(images_classes))

return len(classes * num_samples_per_class)


class SUN397TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.SUN397

Expand Down
2 changes: 2 additions & 0 deletions torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .dtd import DTD
from .fakedata import FakeData
from .fer2013 import FER2013
from .fgvc_aircraft import FGVCAircraft
from .flickr import Flickr8k, Flickr30k
from .flowers102 import Flowers102
from .folder import ImageFolder, DatasetFolder
Expand Down Expand Up @@ -95,4 +96,5 @@
"CLEVRClassification",
"OxfordIIITPet",
"Country211",
"FGVCAircraft",
)
114 changes: 114 additions & 0 deletions torchvision/datasets/fgvc_aircraft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from __future__ import annotations

import os
from typing import Any, Callable, Optional, Tuple

import PIL.Image

from .utils import download_and_extract_archive, verify_str_arg
from .vision import VisionDataset


class FGVCAircraft(VisionDataset):
"""`FGVC Aircraft <https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/>`_ Dataset.

The dataset contains 10,200 images of aircraft, with 100 images for each of 102
different aircraft model variants, most of which are airplanes.
Aircraft models are organized in a three-levels hierarchy. The three levels, from
finer to coarser, are:

- ``variant``, e.g. Boeing 737-700. A variant collapses all the models that are visually
indistinguishable into one class. The dataset comprises 102 different variants.
- ``family``, e.g. Boeing 737. The dataset comprises 70 different families.
- ``manufacturer``, e.g. Boeing. The dataset comprises 41 different manufacturers.

Args:
root (string): Root directory of the FGVC Aircraft dataset.
split (string, optional): The dataset split, supports ``train``, ``val``,
``trainval`` and ``test``.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
annotation_level (str, optional): The annotation level, supports ``variant``,
``family`` and ``manufacturer``.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""

_URL = "https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz"

def __init__(
self,
root: str,
split: str = "trainval",
download: bool = False,
annotation_level: str = "variant",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
super().__init__(root, transform=transform, target_transform=target_transform)
self._split = verify_str_arg(split, "split", ("train", "val", "trainval", "test"))
self._annotation_level = verify_str_arg(
annotation_level, "annotation_level", ("variant", "family", "manufacturer")
)

self._data_path = os.path.join(self.root, "fgvc-aircraft-2013b")
if download:
self._download()

if not self._check_exists():
raise RuntimeError("Dataset not found. You can use download=True to download it")

annotation_file = os.path.join(
self._data_path,
"data",
{
"variant": "variants.txt",
"family": "families.txt",
"manufacturer": "manufacturers.txt",
}[self._annotation_level],
)
with open(annotation_file, "r") as f:
self.classes = [line.strip() for line in f]

self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))

image_data_folder = os.path.join(self._data_path, "data", "images")
labels_file = os.path.join(self._data_path, "data", f"images_{self._annotation_level}_{self._split}.txt")

self._image_files = []
self._labels = []

with open(labels_file, "r") as f:
for line in f:
image_name, label_name = line.strip().split(" ", 1)
self._image_files.append(os.path.join(image_data_folder, f"{image_name}.jpg"))
self._labels.append(self.class_to_idx[label_name])

def __len__(self) -> int:
return len(self._image_files)

def __getitem__(self, idx) -> Tuple[Any, Any]:
image_file, label = self._image_files[idx], self._labels[idx]
image = PIL.Image.open(image_file).convert("RGB")

if self.transform:
image = self.transform(image)

if self.target_transform:
label = self.target_transform(label)

return image, label

def _download(self) -> None:
"""
Download the FGVC Aircraft dataset archive and extract it under root.
"""
if self._check_exists():
return
download_and_extract_archive(self._URL, self.root)

def _check_exists(self) -> bool:
return os.path.exists(self._data_path) and os.path.isdir(self._data_path)