Skip to content

Commit ee2e800

Browse files
prabhat00155pmeier
authored andcommitted
[fbsync] Add GTSRB dataset (#5117)
Summary: * Added GTSRB dataset * Added unittest for GTSRB dataset * Apply suggestions from code review * More changes from code review * readd accidental removed line Reviewed By: sallysyw Differential Revision: D33479282 fbshipit-source-id: 0942e02e5c5459a05536cf49e256c5dcd50c7fec Co-authored-by: Philip Meier <[email protected]>
1 parent 3d4c915 commit ee2e800

File tree

4 files changed

+159
-0
lines changed

4 files changed

+159
-0
lines changed

docs/source/datasets.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
4848
FlyingChairs
4949
FlyingThings3D
5050
Food101
51+
GTSRB
5152
HD1K
5253
HMDB51
5354
ImageNet

test/test_datasets.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2275,5 +2275,55 @@ def inject_fake_data(self, tmpdir, config):
22752275
return num_samples
22762276

22772277

2278+
class GTSRBTestCase(datasets_utils.ImageDatasetTestCase):
2279+
DATASET_CLASS = datasets.GTSRB
2280+
FEATURE_TYPES = (PIL.Image.Image, int)
2281+
2282+
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))
2283+
2284+
def inject_fake_data(self, tmpdir: str, config):
2285+
root_folder = os.path.join(tmpdir, "GTSRB")
2286+
os.makedirs(root_folder, exist_ok=True)
2287+
2288+
# Train data
2289+
train_folder = os.path.join(root_folder, "Training")
2290+
os.makedirs(train_folder, exist_ok=True)
2291+
2292+
num_examples = 3
2293+
classes = ("00000", "00042", "00012")
2294+
for class_idx in classes:
2295+
datasets_utils.create_image_folder(
2296+
train_folder,
2297+
name=class_idx,
2298+
file_name_fn=lambda image_idx: f"{class_idx}_{image_idx:05d}.ppm",
2299+
num_examples=num_examples,
2300+
)
2301+
2302+
total_number_of_examples = num_examples * len(classes)
2303+
# Test data
2304+
test_folder = os.path.join(root_folder, "Final_Test", "Images")
2305+
os.makedirs(test_folder, exist_ok=True)
2306+
2307+
with open(os.path.join(root_folder, "GT-final_test.csv"), "w") as csv_file:
2308+
csv_file.write("Filename;Width;Height;Roi.X1;Roi.Y1;Roi.X2;Roi.Y2;ClassId\n")
2309+
2310+
for _ in range(total_number_of_examples):
2311+
image_file = datasets_utils.create_random_string(5, string.digits) + ".ppm"
2312+
datasets_utils.create_image_file(test_folder, image_file)
2313+
row = [
2314+
image_file,
2315+
torch.randint(1, 100, size=()).item(),
2316+
torch.randint(1, 100, size=()).item(),
2317+
torch.randint(1, 100, size=()).item(),
2318+
torch.randint(1, 100, size=()).item(),
2319+
torch.randint(1, 100, size=()).item(),
2320+
torch.randint(1, 100, size=()).item(),
2321+
torch.randint(0, 43, size=()).item(),
2322+
]
2323+
csv_file.write(";".join(map(str, row)) + "\n")
2324+
2325+
return total_number_of_examples
2326+
2327+
22782328
if __name__ == "__main__":
22792329
unittest.main()

torchvision/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .flickr import Flickr8k, Flickr30k
1111
from .folder import ImageFolder, DatasetFolder
1212
from .food101 import Food101
13+
from .gtsrb import GTSRB
1314
from .hmdb51 import HMDB51
1415
from .imagenet import ImageNet
1516
from .inaturalist import INaturalist
@@ -83,4 +84,5 @@
8384
"Food101",
8485
"DTD",
8586
"FER2013",
87+
"GTSRB",
8688
)

torchvision/datasets/gtsrb.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import csv
2+
import os
3+
from typing import Any, Callable, Optional, Tuple
4+
5+
import PIL
6+
7+
from .folder import make_dataset
8+
from .utils import download_and_extract_archive
9+
from .vision import VisionDataset
10+
11+
12+
class GTSRB(VisionDataset):
13+
"""`German Traffic Sign Recognition Benchmark (GTSRB) <https://benchmark.ini.rub.de/>`_ Dataset.
14+
15+
Args:
16+
root (string): Root directory of the dataset.
17+
train (bool, optional): If True, creates dataset from training set, otherwise
18+
creates from test set.
19+
transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
20+
version. E.g, ``transforms.RandomCrop``.
21+
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
22+
download (bool, optional): If True, downloads the dataset from the internet and
23+
puts it in root directory. If dataset is already downloaded, it is not
24+
downloaded again.
25+
"""
26+
27+
# Ground Truth for the test set
28+
_gt_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Test_GT.zip"
29+
_gt_csv = "GT-final_test.csv"
30+
_gt_md5 = "fe31e9c9270bbcd7b84b7f21a9d9d9e5"
31+
32+
# URLs for the test and train set
33+
_urls = (
34+
"https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Test_Images.zip",
35+
"https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB-Training_fixed.zip",
36+
)
37+
38+
_md5s = ("c7e4e6327067d32654124b0fe9e82185", "513f3c79a4c5141765e10e952eaa2478")
39+
40+
def __init__(
41+
self,
42+
root: str,
43+
train: bool = True,
44+
transform: Optional[Callable] = None,
45+
target_transform: Optional[Callable] = None,
46+
download: bool = False,
47+
) -> None:
48+
49+
super().__init__(root, transform=transform, target_transform=target_transform)
50+
51+
self.root = os.path.expanduser(root)
52+
53+
self.train = train
54+
55+
self._base_folder = os.path.join(self.root, type(self).__name__)
56+
self._target_folder = os.path.join(self._base_folder, "Training" if self.train else "Final_Test/Images")
57+
58+
if download:
59+
self.download()
60+
61+
if not self._check_exists():
62+
raise RuntimeError("Dataset not found. You can use download=True to download it")
63+
64+
if train:
65+
samples = make_dataset(self._target_folder, extensions=(".ppm",))
66+
else:
67+
with open(os.path.join(self._base_folder, self._gt_csv)) as csv_file:
68+
samples = [
69+
(os.path.join(self._target_folder, row["Filename"]), int(row["ClassId"]))
70+
for row in csv.DictReader(csv_file, delimiter=";", skipinitialspace=True)
71+
]
72+
73+
self._samples = samples
74+
self.transform = transform
75+
self.target_transform = target_transform
76+
77+
def __len__(self) -> int:
78+
return len(self._samples)
79+
80+
def __getitem__(self, index: int) -> Tuple[Any, Any]:
81+
82+
path, target = self._samples[index]
83+
sample = PIL.Image.open(path).convert("RGB")
84+
85+
if self.transform is not None:
86+
sample = self.transform(sample)
87+
88+
if self.target_transform is not None:
89+
target = self.target_transform(target)
90+
91+
return sample, target
92+
93+
def _check_exists(self) -> bool:
94+
return os.path.exists(self._target_folder) and os.path.isdir(self._target_folder)
95+
96+
def download(self) -> None:
97+
if self._check_exists():
98+
return
99+
100+
download_and_extract_archive(self._urls[self.train], download_root=self.root, md5=self._md5s[self.train])
101+
102+
if not self.train:
103+
# Download Ground Truth for the test set
104+
download_and_extract_archive(
105+
self._gt_url, download_root=self.root, extract_root=self._base_folder, md5=self._gt_md5
106+
)

0 commit comments

Comments
 (0)