Skip to content

Adding fvgc_aircraft dataset #5178

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jan 14, 2022
1 change: 1 addition & 0 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
FlyingChairs
FlyingThings3D
Food101
FVGCAircraft
GTSRB
HD1K
HMDB51
Expand Down
46 changes: 46 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2206,6 +2206,52 @@ def inject_fake_data(self, tmpdir: str, config):
return len(sampled_classes * n_samples_per_class)


class FVGCAircraftTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.FVGCAircraft
FEATURE_TYPES = (PIL.Image.Image, int)

ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val", "trainval", "test"))

def inject_fake_data(self, tmpdir: str, config):
split = config["split"]
root_folder = pathlib.Path(tmpdir) / "fgvc-aircraft-2013b"
data_folder = root_folder / "data"

num_images_per_class = 5
variants = ["707-320", "Hawk T1", "Tornado"]
n_samples_per_class = 4 if split == "trainval" else 2

datasets_utils.create_image_folder(
data_folder,
"images",
file_name_fn=lambda idx: f"{idx}.jpg",
num_examples=num_images_per_class * len(variants),
)

images_variants = []
for i in range(len(variants)):
variant = variants[i]
images_variants.extend(
[
f"{idx} {variant}"
for idx in random.sample(
range(i * num_images_per_class, (i + 1) * num_images_per_class), n_samples_per_class
)
]
)

varients_file = root_folder / "data" / "variants.txt"
images_variant_file = root_folder / "data" / f"images_variant_{split}.txt"

with open(varients_file, "w") as file:
file.write("\n".join(variants))

with open(images_variant_file, "w") as file:
file.write("\n".join(images_variants))

return len(variants * n_samples_per_class)


class SUN397TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.SUN397

Expand Down
2 changes: 2 additions & 0 deletions torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .flickr import Flickr8k, Flickr30k
from .folder import ImageFolder, DatasetFolder
from .food101 import Food101
from .fvgc_aircraft import FVGCAircraft
from .gtsrb import GTSRB
from .hmdb51 import HMDB51
from .imagenet import ImageNet
Expand Down Expand Up @@ -91,4 +92,5 @@
"GTSRB",
"CLEVRClassification",
"OxfordIIITPet",
"FVGCAircraft",
)
119 changes: 119 additions & 0 deletions torchvision/datasets/fvgc_aircraft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import os
import shutil
from typing import Any, Callable, List, Optional, Tuple

import PIL.Image

from .utils import download_and_extract_archive, verify_str_arg
from .vision import VisionDataset


class FVGCAircraft(VisionDataset):
"""`FVGC Aircraft <https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/>`_ Dataset.

The dataset contains 10,200 images of aircraft, with 100 images for each of 102
different aircraft model variants, most of which are airplanes.

Args:
root (string): Root directory of the FVGC Aircraft dataset.
split (string, optional): The dataset split, supports ``train``, ``val``,
``trainval`` and ``test``.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""

_URL = "https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/"
_URL_FILE = "fgvc-aircraft-2013b.tar.gz"

def __init__(
self,
root: str,
split: str = "trainval",
download: Optional[str] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
**kwargs: Any,
) -> None:
super().__init__(root, transform=transform, target_transform=target_transform)
self._split = verify_str_arg(split, "split", ("train", "val", "trainval", "test"))

self._data_path = os.path.join(self.root, "fgvc-aircraft-2013b")
if download:
self._download()

if not self._check_exists():
raise RuntimeError("Dataset not found. You can use download=True to download it")

self._label_names = sorted(self._get_label_names(self._data_path))

# Parse the downloaded files
self._image_folder = os.path.join(self.root, self._split)
self._create_fgvc_aircrafts_disk_folder(self._data_path)

self._label_name_to_idx = dict(zip(self._label_names, range(len(self._label_names))))

self._image_files = []
self._labels = []
for label_name in self._label_names:
img_rel_folder = os.path.join(self._image_folder, label_name)
img_file_name_list = [
f for f in os.listdir(img_rel_folder) if os.path.isfile(os.path.join(img_rel_folder, f))
]
self._labels += [self._label_name_to_idx[label_name]] * len(img_file_name_list)
self._image_files += [os.path.join(img_rel_folder, img_name) for img_name in img_file_name_list]

def __len__(self) -> int:
return len(self._image_files)

def __getitem__(self, idx) -> Tuple[Any, Any]:
image_file, label = self._image_files[idx], self._labels[idx]
image = PIL.Image.open(image_file).convert("RGB")

if self.transform:
image = self.transform(image)

if self.target_transform:
label = self.target_transform(label)

return image, label

def _download(self):
"""
Download the FGVC Aircraft dataset archive and extract it under root.
"""
if self._check_exists():
return
download_and_extract_archive(self._URL + self._URL_FILE, self.root)

def _check_exists(self) -> bool:
return os.path.exists(self._data_path) and os.path.isdir(self._data_path)

def _create_fgvc_aircrafts_disk_folder(self, input_path: str):
img_data_folder = os.path.join(input_path, "data", "images")
labels_path = os.path.join(input_path, "data", f"images_variant_{self._split}.txt")
for label in self._label_names:
os.makedirs(os.path.join(self._image_folder, label), exist_ok=True)

with open(labels_path, "r") as labels_file:
lines = [line.strip() for line in labels_file]
for line in lines:
line_list = line.split(" ")
image_name = line_list[0]
label_name = self._parse_aircraft_name(" ".join(line_list[1:]))
shutil.copy(
src=os.path.join(img_data_folder, f"{image_name}.jpg"),
dst=os.path.join(self._image_folder, label_name),
)

def _get_label_names(self, input_path: str) -> List[str]:
variants_file = os.path.join(input_path, "data", "variants.txt")
with open(variants_file, "r") as f:
return [self._parse_aircraft_name(line.strip()) for line in f]

def _parse_aircraft_name(self, name: str) -> str:
return name.replace("/", "-").replace(" ", "-")