+import io
+import pathlib
+import re
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from torch.utils.data import IterDataPipe
+from torch.utils.data.datapipes.iter import Mapper, TarArchiveReader, Shuffler, Demultiplexer, Filter
+from torchdata.datapipes.iter import KeyZipper, LineReader
+from torchvision.prototype.datasets.utils import (
+    Dataset,
+    DatasetConfig,
+    DatasetInfo,
+    HttpResource,
+    OnlineResource,
+    DatasetType,
+)
+from torchvision.prototype.datasets.utils._internal import (
+    create_categories_file,
+    INFINITE_BUFFER_SIZE,
+    read_mat,
+    getitem,
+    path_accessor,
+    path_comparator,
+)
+
+HERE = pathlib.Path(__file__).parent
+
+
+class SBD(Dataset):
+    @property
+    def info(self) -> DatasetInfo:
+        return DatasetInfo(
+            "sbd",
+            type=DatasetType.IMAGE,
+            categories=HERE / "sbd.categories",
+            homepage="http://home.bharathh.info/pubs/codes/SBD/download.html",
+            valid_options=dict(
+                split=("train", "val", "train_noval"),
+                boundaries=(True, False),
+                segmentation=(False, True),
+            ),
+        )
+
+    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
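+        # the "train_noval" split uses a separate file list hosted next to the main archive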
+        archive = HttpResource(
+            "https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz",
+            sha256="6a5a2918d5c73ce032fdeba876574d150d9d04113ab87540a1304cbcc715be53",
+        )
+        extra_split = HttpResource(
+            "http://home.bharathh.info/pubs/codes/SBD/train_noval.txt",
+            sha256="0b2068f7a359d2907431803e1cd63bf6162da37d7d503b589d3b08c6fd0c2432",
+        )
+        return [archive, extra_split]
+
+    def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
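+        # Demultiplexer buckets: 0 -> split lists (*.txt) directly inside "dataset/",
+        # 1 -> images in "dataset/img/", 2 -> class annotations in "dataset/cls/";
+        # everything else is dropped via drop_none=True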
+        path = pathlib.Path(data[0])
+        parent, grandparent, *_ = path.parents
+
+        if parent.name == "dataset":
+            return 0
+        elif grandparent.name == "dataset":
+            if parent.name == "img":
+                return 1
+            elif parent.name == "cls":
+                return 2
+            else:
+                return None
+        else:
+            return None
+
+    def _decode_ann(
+        self, data: Dict[str, Any], *, decode_boundaries: bool, decode_segmentation: bool
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
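+        # `data` is the dict returned by read_mat(); "GTcls" holds the class-level ground truth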
+        raw_anns = data["GTcls"][0]
+        raw_boundaries = raw_anns["Boundaries"][0]
+        raw_segmentation = raw_anns["Segmentation"][0]
+
+        # the boundaries are stored in sparse CSC format, which is not supported by PyTorch
+        boundaries = (
+            torch.as_tensor(np.stack([raw_boundary[0].toarray() for raw_boundary in raw_boundaries]))
+            if decode_boundaries
+            else None
+        )
+        segmentation = torch.as_tensor(raw_segmentation) if decode_segmentation else None
+
+        return boundaries, segmentation
+
+    def _collate_and_decode_sample(
+        self,
+        data: Tuple[Tuple[Any, Tuple[str, io.IOBase]], Tuple[str, io.IOBase]],
+        *,
+        config: DatasetConfig,
+        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+    ) -> Dict[str, Any]:
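+        # `data` comes out of the two nested KeyZippers:
+        # ((split_line, (image_path, image_buffer)), (ann_path, ann_buffer))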
+        split_and_image_data, ann_data = data
+        _, image_data = split_and_image_data
+        image_path, image_buffer = image_data
+        ann_path, ann_buffer = ann_data
+
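+        # without a decoder, hand back the raw buffer so the caller can decode lazily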
+        image = decoder(image_buffer) if decoder else image_buffer
+
+        if config.boundaries or config.segmentation:
+            boundaries, segmentation = self._decode_ann(
+                read_mat(ann_buffer), decode_boundaries=config.boundaries, decode_segmentation=config.segmentation
+            )
+        else:
+            boundaries = segmentation = None
+
+        return dict(
+            image_path=image_path,
+            image=image,
+            ann_path=ann_path,
+            boundaries=boundaries,
+            segmentation=segmentation,
+        )
+
+    def _make_datapipe(
+        self,
+        resource_dps: List[IterDataPipe],
+        *,
+        config: DatasetConfig,
+        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+    ) -> IterDataPipe[Dict[str, Any]]:
+        archive_dp, extra_split_dp = resource_dps
+
+        archive_dp = TarArchiveReader(archive_dp)
+        split_dp, images_dp, anns_dp = Demultiplexer(
+            archive_dp,
+            3,
+            self._classify_archive,  # type: ignore[arg-type]
+            buffer_size=INFINITE_BUFFER_SIZE,
+            drop_none=True,
+        )
+
+        if config.split == "train_noval":
+            split_dp = extra_split_dp
+        else:
+            # the first demux branch contains both train.txt and val.txt; keep only the requested split list
+            split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
+        split_dp = LineReader(split_dp, decode=True)
+        split_dp = Shuffler(split_dp)
+
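+        # zip the split ids first with the images and then with the annotations;
+        # each join matches the id line against the file stem of the candidate files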
+        dp = split_dp
+        for level, data_dp in enumerate((images_dp, anns_dp)):
+            dp = KeyZipper(
+                dp,
+                data_dp,
+                key_fn=getitem(*[0] * level, 1),
+                ref_key_fn=path_accessor("stem"),
+                buffer_size=INFINITE_BUFFER_SIZE,
+            )
+        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(config=config, decoder=decoder))
+
+    def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None:
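+        # parse category_names.m from the downloaded archive to recover the category names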
+        dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
+        dp = TarArchiveReader(dp)
+        dp: IterDataPipe = Filter(dp, path_comparator("name", "category_names.m"))
+        dp = LineReader(dp)
+        dp: IterDataPipe = Mapper(dp, bytes.decode, input_col=1)
+        lines = tuple(zip(*iter(dp)))[1]
+
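+        # category_names.m entries have the form "'<category>'; %<label>"; extract both groups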
+        pattern = re.compile(r"\s*'(?P<category>\w+)';\s*%(?P<label>\d+)")
+        categories_and_labels = [
+            pattern.match(line).groups()  # type: ignore[union-attr]
+            # the first and last line contain no information
+            for line in lines[1:-1]
+        ]
+        categories = tuple(
+            zip(*sorted(categories_and_labels, key=lambda category_and_label: int(category_and_label[1])))
+        )[0]
+
+        create_categories_file(HERE, self.name, categories)
+
+
+if __name__ == "__main__":
+    from torchvision.prototype.datasets import home
+
+    root = home()
+    SBD().generate_categories_file(root)