Skip to content

Add GTSRB dataset to the list of supporting datasets #5117

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
FlyingChairs
FlyingThings3D
Food101
GTSRB
HD1K
HMDB51
ImageNet
Expand Down
50 changes: 50 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2275,5 +2275,55 @@ def inject_fake_data(self, tmpdir, config):
return num_samples


class GTSRBTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.GTSRB
FEATURE_TYPES = (PIL.Image.Image, int)

ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))

def inject_fake_data(self, tmpdir: str, config):
root_folder = os.path.join(tmpdir, "GTSRB")
os.makedirs(root_folder, exist_ok=True)

# Train data
train_folder = os.path.join(root_folder, "Training")
os.makedirs(train_folder, exist_ok=True)

num_examples = 3
classes = ("00000", "00042", "00012")
for class_idx in classes:
datasets_utils.create_image_folder(
train_folder,
name=class_idx,
file_name_fn=lambda image_idx: f"{class_idx}_{image_idx:05d}.ppm",
num_examples=num_examples,
)

total_number_of_examples = num_examples * len(classes)
# Test data
test_folder = os.path.join(root_folder, "Final_Test", "Images")
os.makedirs(test_folder, exist_ok=True)

with open(os.path.join(root_folder, "GT-final_test.csv"), "w") as csv_file:
csv_file.write("Filename;Width;Height;Roi.X1;Roi.Y1;Roi.X2;Roi.Y2;ClassId\n")

for _ in range(total_number_of_examples):
image_file = datasets_utils.create_random_string(5, string.digits) + ".ppm"
datasets_utils.create_image_file(test_folder, image_file)
row = [
image_file,
torch.randint(1, 100, size=()).item(),
torch.randint(1, 100, size=()).item(),
torch.randint(1, 100, size=()).item(),
torch.randint(1, 100, size=()).item(),
torch.randint(1, 100, size=()).item(),
torch.randint(1, 100, size=()).item(),
torch.randint(0, 43, size=()).item(),
]
csv_file.write(";".join(map(str, row)) + "\n")

return total_number_of_examples


if __name__ == "__main__":
unittest.main()
2 changes: 2 additions & 0 deletions torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .flickr import Flickr8k, Flickr30k
from .folder import ImageFolder, DatasetFolder
from .food101 import Food101
from .gtsrb import GTSRB
from .hmdb51 import HMDB51
from .imagenet import ImageNet
from .inaturalist import INaturalist
Expand Down Expand Up @@ -83,4 +84,5 @@
"Food101",
"DTD",
"FER2013",
"GTSRB",
)
106 changes: 106 additions & 0 deletions torchvision/datasets/gtsrb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import csv
import os
from typing import Any, Callable, Optional, Tuple

import PIL

from .folder import make_dataset
from .utils import download_and_extract_archive
from .vision import VisionDataset


class GTSRB(VisionDataset):
"""`German Traffic Sign Recognition Benchmark (GTSRB) <https://benchmark.ini.rub.de/>`_ Dataset.

Args:
root (string): Root directory of the dataset.
train (bool, optional): If True, creates dataset from training set, otherwise
creates from test set.
transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``.
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""

# Ground Truth for the test set
_gt_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Test_GT.zip"
_gt_csv = "GT-final_test.csv"
_gt_md5 = "fe31e9c9270bbcd7b84b7f21a9d9d9e5"

# URLs for the test and train set
_urls = (
"https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Test_Images.zip",
"https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB-Training_fixed.zip",
)

_md5s = ("c7e4e6327067d32654124b0fe9e82185", "513f3c79a4c5141765e10e952eaa2478")

def __init__(
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:

super().__init__(root, transform=transform, target_transform=target_transform)

self.root = os.path.expanduser(root)

self.train = train

self._base_folder = os.path.join(self.root, type(self).__name__)
self._target_folder = os.path.join(self._base_folder, "Training" if self.train else "Final_Test/Images")

if download:
self.download()

if not self._check_exists():
raise RuntimeError("Dataset not found. You can use download=True to download it")

if train:
samples = make_dataset(self._target_folder, extensions=(".ppm",))
else:
with open(os.path.join(self._base_folder, self._gt_csv)) as csv_file:
samples = [
(os.path.join(self._target_folder, row["Filename"]), int(row["ClassId"]))
for row in csv.DictReader(csv_file, delimiter=";", skipinitialspace=True)
]

self._samples = samples
self.transform = transform
self.target_transform = target_transform

def __len__(self) -> int:
return len(self._samples)

def __getitem__(self, index: int) -> Tuple[Any, Any]:

path, target = self._samples[index]
sample = PIL.Image.open(path).convert("RGB")

if self.transform is not None:
sample = self.transform(sample)

if self.target_transform is not None:
target = self.target_transform(target)

return sample, target

def _check_exists(self) -> bool:
return os.path.exists(self._target_folder) and os.path.isdir(self._target_folder)

def download(self) -> None:
if self._check_exists():
return

download_and_extract_archive(self._urls[self.train], download_root=self.root, md5=self._md5s[self.train])

if not self.train:
# Download Ground Truth for the test set
download_and_extract_archive(
self._gt_url, download_root=self.root, extract_root=self._base_folder, md5=self._gt_md5
)