Skip to content

Commit 493d301

Browse files
authored
add prototype for SBD dataset (#4537)
* add prototype for SBD dataset * cleanup
1 parent ba11155 commit 493d301

File tree

4 files changed

+201
-1
lines changed

4 files changed

+201
-1
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .caltech import Caltech101, Caltech256
22
from .cifar import Cifar10, Cifar100
3+
from .sbd import SBD
34
from .voc import VOC
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
aeroplane
2+
bicycle
3+
bird
4+
boat
5+
bottle
6+
bus
7+
car
8+
cat
9+
chair
10+
cow
11+
diningtable
12+
dog
13+
horse
14+
motorbike
15+
person
16+
pottedplant
17+
sheep
18+
sofa
19+
train
20+
tvmonitor
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
import io
2+
import pathlib
3+
import re
4+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5+
6+
import numpy as np
7+
import torch
8+
from torch.utils.data import IterDataPipe
9+
from torch.utils.data.datapipes.iter import Mapper, TarArchiveReader, Shuffler, Demultiplexer, Filter
10+
from torchdata.datapipes.iter import KeyZipper, LineReader
11+
from torchvision.prototype.datasets.utils import (
12+
Dataset,
13+
DatasetConfig,
14+
DatasetInfo,
15+
HttpResource,
16+
OnlineResource,
17+
DatasetType,
18+
)
19+
from torchvision.prototype.datasets.utils._internal import (
20+
create_categories_file,
21+
INFINITE_BUFFER_SIZE,
22+
read_mat,
23+
getitem,
24+
path_accessor,
25+
path_comparator,
26+
)
27+
28+
HERE = pathlib.Path(__file__).parent
29+
30+
31+
class SBD(Dataset):
    """Prototype dataset for SBD (see the homepage listed in :attr:`info`).

    A sample is assembled by joining the image IDs listed in a split file
    with the matching file from the ``img`` folder and the matching ``.mat``
    annotation from the ``cls`` folder of the archive's ``dataset`` directory.
    """

    @property
    def info(self) -> DatasetInfo:
        """Static description of the dataset: name, type, categories, homepage and options."""
        return DatasetInfo(
            "sbd",
            type=DatasetType.IMAGE,
            # SBD has its own category list; it is stored next to this module
            # under the dataset name (see generate_categories_file()).
            categories=HERE / "sbd.categories",
            homepage="http://home.bharathh.info/pubs/codes/SBD/download.html",
            valid_options=dict(
                split=("train", "val", "train_noval"),
                boundaries=(True, False),
                segmentation=(False, True),
            ),
        )

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        """Return the resources to download.

        The list always has two entries, in this order: the main archive and
        the separately hosted ``train_noval`` split file.
        """
        archive = HttpResource(
            "https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz",
            sha256="6a5a2918d5c73ce032fdeba876574d150d9d04113ab87540a1304cbcc715be53",
        )
        extra_split = HttpResource(
            "http://home.bharathh.info/pubs/codes/SBD/train_noval.txt",
            sha256="0b2068f7a359d2907431803e1cd63bf6162da37d7d503b589d3b08c6fd0c2432",
        )
        return [archive, extra_split]

    def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
        """Route an archive member to a demultiplexer branch based on its path.

        Args:
            data: ``(path, payload)`` tuple; only the path is inspected.

        Returns:
            ``0`` for files directly inside a ``dataset`` folder (split files),
            ``1`` for files in ``dataset/img``, ``2`` for files in
            ``dataset/cls`` and ``None`` for everything else.
        """
        path = pathlib.Path(data[0])
        # ``.parent`` never raises, even for very shallow paths, whereas
        # unpacking ``path.parents`` needs at least two ancestors.
        parent = path.parent
        grandparent = parent.parent

        if parent.name == "dataset":
            return 0
        if grandparent.name == "dataset":
            if parent.name == "img":
                return 1
            if parent.name == "cls":
                return 2
        return None

    def _decode_ann(
        self, data: Dict[str, Any], *, decode_boundaries: bool, decode_segmentation: bool
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Extract boundaries and segmentation from a loaded annotation.

        Args:
            data: Content of an annotation file; read through its ``"GTcls"`` entry.
            decode_boundaries: Whether to decode the ``"Boundaries"`` field.
            decode_segmentation: Whether to decode the ``"Segmentation"`` field.

        Returns:
            ``(boundaries, segmentation)``; an item is ``None`` if its decoding
            was not requested.
        """
        raw_anns = data["GTcls"][0]
        raw_boundaries = raw_anns["Boundaries"][0]
        raw_segmentation = raw_anns["Segmentation"][0]

        # the boundaries are stored in sparse CSC format, which is not supported by PyTorch
        boundaries = (
            torch.as_tensor(np.stack([raw_boundary[0].toarray() for raw_boundary in raw_boundaries]))
            if decode_boundaries
            else None
        )
        segmentation = torch.as_tensor(raw_segmentation) if decode_segmentation else None

        return boundaries, segmentation

    def _collate_and_decode_sample(
        self,
        data: Tuple[Tuple[Any, Tuple[str, io.IOBase]], Tuple[str, io.IOBase]],
        *,
        config: DatasetConfig,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> Dict[str, Any]:
        """Turn one joined ``((split, image), annotation)`` item into a sample dictionary.

        Args:
            data: Nested tuple as produced by the two joins in :meth:`_make_datapipe`.
            config: Selects through ``boundaries`` / ``segmentation`` which
                annotation parts are decoded.
            decoder: Optional image decoder; without it the raw image buffer is returned.

        Returns:
            Dictionary with the keys ``image_path``, ``image``, ``ann_path``,
            ``boundaries`` and ``segmentation``.
        """
        split_and_image_data, ann_data = data
        _, image_data = split_and_image_data
        image_path, image_buffer = image_data
        ann_path, ann_buffer = ann_data

        image = decoder(image_buffer) if decoder else image_buffer

        # Only read the annotation file if at least one part of it is requested.
        if config.boundaries or config.segmentation:
            boundaries, segmentation = self._decode_ann(
                read_mat(ann_buffer), decode_boundaries=config.boundaries, decode_segmentation=config.segmentation
            )
        else:
            boundaries = segmentation = None

        return dict(
            image_path=image_path,
            image=image,
            ann_path=ann_path,
            boundaries=boundaries,
            segmentation=segmentation,
        )

    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> IterDataPipe[Dict[str, Any]]:
        """Build the sample datapipe for the given configuration.

        Args:
            resource_dps: Datapipes for the resources, in the order returned by
                :meth:`resources`.
            config: Dataset configuration; ``split`` selects the split file.
            decoder: Optional image decoder forwarded to
                :meth:`_collate_and_decode_sample`.

        Returns:
            Datapipe yielding one sample dictionary per image ID of the split.
        """
        archive_dp, extra_split_dp = resource_dps

        archive_dp = TarArchiveReader(archive_dp)
        split_dp, images_dp, anns_dp = Demultiplexer(
            archive_dp,
            3,
            self._classify_archive,  # type: ignore[arg-type]
            buffer_size=INFINITE_BUFFER_SIZE,
            drop_none=True,
        )

        if config.split == "train_noval":
            split_dp = extra_split_dp
        else:
            # The first demultiplexer branch receives every file located
            # directly in the ``dataset`` folder, so keep only the file that
            # belongs to the requested split.
            split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
        split_dp = LineReader(split_dp, decode=True)
        split_dp = Shuffler(split_dp)

        dp = split_dp
        for level, data_dp in enumerate((images_dp, anns_dp)):
            # Every join wraps the previous item into one more tuple, so the
            # image ID (index 1 of the innermost split item) sits one level
            # deeper each time.
            dp = KeyZipper(
                dp,
                data_dp,
                key_fn=getitem(*[0] * level, 1),
                ref_key_fn=path_accessor("stem"),
                buffer_size=INFINITE_BUFFER_SIZE,
            )
        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(config=config, decoder=decoder))

    def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None:
        """Create the categories file from ``category_names.m`` inside the archive.

        Args:
            root: Directory under which the dataset folder (``root / self.name``)
                with the archive is located.
        """
        dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
        dp = TarArchiveReader(dp)
        dp: IterDataPipe = Filter(dp, path_comparator("name", "category_names.m"))
        dp = LineReader(dp)
        dp: IterDataPipe = Mapper(dp, bytes.decode, input_col=1)
        # Items are (path, line) pairs; transpose them and keep only the lines.
        lines = tuple(zip(*iter(dp)))[1]

        # Matches lines of the form ``'<category>'; %<label>``.
        pattern = re.compile(r"\s*'(?P<category>\w+)';\s*%(?P<label>\d+)")
        categories_and_labels = [
            pattern.match(line).groups()  # type: ignore[union-attr]
            # the first and last line contain no information
            for line in lines[1:-1]
        ]
        # Order the categories by their numeric label before dropping the label.
        categories = tuple(
            zip(*sorted(categories_and_labels, key=lambda category_and_label: int(category_and_label[1])))
        )[0]

        create_categories_file(HERE, self.name, categories)
173+
174+
175+
if __name__ == "__main__":
    # Executing this module as a script (re)creates the categories file,
    # using the directory returned by home() as the root.
    from torchvision.prototype.datasets import home

    SBD().generate_categories_file(home())

torchvision/prototype/datasets/utils/_internal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import difflib
33
import io
44
import pathlib
5-
from typing import Collection, Sequence, Callable, Union, Iterator, Tuple, TypeVar, Dict, Any
5+
from typing import Collection, Sequence, Callable, Union, Any, Tuple, TypeVar, Iterator, Dict
66

77
import numpy as np
88
import PIL.Image

0 commit comments

Comments
 (0)