Skip to content

Commit 7b40f9c

Browse files
committed
Implementation for cityscapes in proto datasets
1 parent e13206d commit 7b40f9c

File tree

2 files changed

+183
-0
lines changed

2 files changed

+183
-0
lines changed

torchvision/prototype/datasets/_builtin/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .caltech import Caltech101, Caltech256
22
from .celeba import CelebA
33
from .cifar import Cifar10, Cifar100
4+
from .cityscapes import Cityscapes
45
from .clevr import CLEVR
56
from .coco import Coco
67
from .cub200 import CUB200
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
from functools import partial
2+
from pathlib import Path
3+
from typing import Any, Dict, List
4+
5+
from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, Demultiplexer, IterKeyZipper, JsonParser
6+
from torchvision.prototype.datasets.utils import (
7+
Dataset,
8+
DatasetInfo,
9+
DatasetConfig,
10+
ManualDownloadResource,
11+
OnlineResource,
12+
)
13+
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
14+
from torchvision.prototype.features import EncodedImage
15+
16+
17+
class CityscapesDatasetInfo(DatasetInfo):
    """Dataset info for Cityscapes that rules out split/mode pairings the dataset does not ship."""

    @staticmethod
    def _is_unsupported(split: str, mode: str) -> bool:
        # The 'test' split exists only for fine annotations; 'train_extra' only for coarse.
        return (split == "test" and mode == "coarse") or (split == "train_extra" and mode == "fine")

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        # Drop every auto-generated config whose split/mode combination is unsupported.
        supported = [cfg for cfg in self._configs if not self._is_unsupported(cfg.split, cfg.mode)]
        self._configs = tuple(supported)

    def make_config(self, **options: Any) -> DatasetConfig:
        """Build a config, raising ``ValueError`` for unsupported split/mode combinations."""
        config = super().make_config(**options)
        if config.mode == "coarse" and config.split == "test":
            raise ValueError("`split='test'` is only available for `mode='fine'`")
        if config.mode == "fine" and config.split == "train_extra":
            raise ValueError("`split='train_extra'` is only available for `mode='coarse'`")
        return config
37+
38+
39+
class CityscapesResource(ManualDownloadResource):
    """Resource that the user must download manually after registering on the Cityscapes site."""

    # Instructions surfaced to the user when the archive is not found locally.
    _INSTRUCTIONS = "Register on https://www.cityscapes-dataset.com/login/ and follow the instructions there."

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(self._INSTRUCTIONS, **kwargs)
44+
45+
46+
class Cityscapes(Dataset):
    """Cityscapes dataset (prototype datapipe implementation).

    Each sample pairs a left-camera image with four target files found in the
    annotation archive: instance ids (PNG), label/semantic ids (PNG), polygon
    annotations (JSON, decoded), and a color-coded segmentation image (PNG).
    """

    def _make_info(self) -> DatasetInfo:
        name = "cityscapes"
        # Category names are not extracted from the archives here.
        categories = None

        return CityscapesDatasetInfo(
            name,
            categories=categories,
            homepage="http://www.cityscapes-dataset.com/",
            valid_options=dict(
                split=("train", "val", "test", "train_extra"),
                mode=("fine", "coarse"),
                # target_type=("instance", "semantic", "polygon", "color")
            ),
        )

    # SHA256 checksums of the manually downloaded archives.
    _FILES_CHECKSUMS = {
        "gtCoarse.zip": "3555e09349ed49127053d940eaa66a87a79a175662b329c1a26a58d47e602b5b",
        "gtFine_trainvaltest.zip": "40461a50097844f400fef147ecaf58b18fd99e14e4917fb7c3bf9c0d87d95884",
        "leftImg8bit_trainextra.zip": "e41cc14c0c06aad051d52042465d9b8c22bacf6e4c93bb98de273ed7177b7133",
        "leftImg8bit_trainvaltest.zip": "3ccff9ac1fa1d80a6a064407e589d747ed0657aac7dc495a4403ae1235a37525",
    }

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        """Return the image archive and the matching annotation archive for ``config.mode``."""
        if config.mode == "fine":
            resources = [
                CityscapesResource(
                    file_name="leftImg8bit_trainvaltest.zip",
                    sha256=self._FILES_CHECKSUMS["leftImg8bit_trainvaltest.zip"],
                ),
                CityscapesResource(
                    file_name="gtFine_trainvaltest.zip", sha256=self._FILES_CHECKSUMS["gtFine_trainvaltest.zip"]
                ),
            ]
        else:
            resources = [
                CityscapesResource(
                    file_name="leftImg8bit_trainextra.zip", sha256=self._FILES_CHECKSUMS["leftImg8bit_trainextra.zip"]
                ),
                CityscapesResource(file_name="gtCoarse.zip", sha256=self._FILES_CHECKSUMS["gtCoarse.zip"]),
            ]
        return resources

    def _filter_split_images(self, data, *, req_split: str):
        """Keep only PNG images belonging to ``req_split``.

        Assumes the archive layout ``<split>/<city>/<file>.png`` so the split name
        is two path components above the file — TODO confirm against the archives.
        """
        path = Path(data[0])
        split = path.parent.parts[-2]
        return split == req_split and ".png" == path.suffix

    def _filter_classify_targets(self, data, *, req_split: str):
        """Classify a target file into one of 4 demux buckets, or ``None`` to drop it.

        Bucket order (instance, label, polygon, color) must match the left-deep
        unpacking performed in ``_prepare_sample``.
        """
        path = Path(data[0])
        name = path.name
        split = path.parent.parts[-2]
        if split != req_split:
            return None
        for i, target_type in enumerate(["instance", "label", "polygon", "color"]):
            # Polygon annotations are JSON; every other target type is a PNG.
            ext = ".json" if target_type == "polygon" else ".png"
            if ext in path.suffix and target_type in name:
                return i
        return None

    def _prepare_sample(self, data):
        """Assemble one sample dict from a zipped ``(image, nested_targets)`` entry.

        The successive ``IterKeyZipper``s in ``_make_datapipe`` nest the targets
        left-deep: ``(((instance, label), polygon), color)``.
        """
        (img_path, img_data), target_data = data

        color_path, color_data = target_data[1]
        target_data = target_data[0]
        polygon_path, polygon_data = target_data[1]
        target_data = target_data[0]
        label_path, label_data = target_data[1]
        target_data = target_data[0]
        instance_path, instance_data = target_data

        return dict(
            image_path=img_path,
            image=EncodedImage.from_file(img_data),
            color_path=color_path,
            color=EncodedImage.from_file(color_data),
            polygon_path=polygon_path,
            polygon=polygon_data,
            segmentation_path=label_path,
            segmentation=EncodedImage.from_file(label_data),
            # Fix: was `color_path`, which reported the color image's path for the
            # instance-id target; `instance_path` is the path actually read below.
            instances_path=instance_path,
            instances=EncodedImage.from_file(instance_data),
        )

    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
    ) -> IterDataPipe[Dict[str, Any]]:
        """Build the sample pipeline: filter by split, demux the 4 target types,
        zip targets together by stem, then zip each image with its targets."""
        archive_images, archive_targets = resource_dps

        images_dp = Filter(archive_images, filter_fn=partial(self._filter_split_images, req_split=config.split))

        targets_dps = Demultiplexer(
            archive_targets,
            4,
            classifier_fn=partial(self._filter_classify_targets, req_split=config.split),
            drop_none=True,
            buffer_size=INFINITE_BUFFER_SIZE,
        )

        # targets_dps[2] holds the JSON polygon files; decode them eagerly.
        targets_dps[2] = JsonParser(targets_dps[2])

        def img_key_fn(data):
            # "<stem>_leftImg8bit.png" -> "<stem>"
            stem = Path(data[0]).stem
            stem = stem[: -len("_leftImg8bit")]
            return stem

        def target_key_fn(data, level=0):
            # Descend `level` times into the left-nested tuples produced by the
            # zippers below, then strip the "_gt..." suffix from the file stem.
            path = data[0]
            for _ in range(level):
                path = path[0]
            stem = Path(path).stem
            i = stem.rfind("_gt")
            stem = stem[:i]
            return stem

        # Zip all 4 target streams into one left-deep nested stream keyed by stem.
        zipped_targets_dp = targets_dps[0]
        for level, data_dp in enumerate(targets_dps[1:]):
            zipped_targets_dp = IterKeyZipper(
                zipped_targets_dp,
                data_dp,
                key_fn=partial(target_key_fn, level=level),
                ref_key_fn=target_key_fn,
                buffer_size=INFINITE_BUFFER_SIZE,
            )

        samples = IterKeyZipper(
            images_dp,
            zipped_targets_dp,
            key_fn=img_key_fn,
            ref_key_fn=partial(target_key_fn, level=len(targets_dps) - 1),
            buffer_size=INFINITE_BUFFER_SIZE,
        )
        return Mapper(samples, fn=self._prepare_sample)

0 commit comments

Comments
 (0)