Skip to content

Commit e4a4a29

Browse files
pmeierprabhat00155
andauthored
streamline category file generation for prototype datasets (#4642)
* streamline category file generation for prototype datasets * cleanup Co-authored-by: Prabhat Roy <[email protected]>
1 parent 9bee9cc commit e4a4a29

File tree

6 files changed

+105
-62
lines changed

6 files changed

+105
-62
lines changed

torchvision/prototype/datasets/_builtin/caltech.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import io
22
import pathlib
33
import re
4-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
4+
from typing import Any, Callable, Dict, List, Optional, Tuple
55

66
import numpy as np
77
import torch
@@ -21,9 +21,7 @@
2121
OnlineResource,
2222
DatasetType,
2323
)
24-
from torchvision.prototype.datasets.utils._internal import create_categories_file, INFINITE_BUFFER_SIZE, read_mat
25-
26-
HERE = pathlib.Path(__file__).parent
24+
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, BUILTIN_DIR, read_mat
2725

2826

2927
class Caltech101(Dataset):
@@ -32,7 +30,7 @@ def info(self) -> DatasetInfo:
3230
return DatasetInfo(
3331
"caltech101",
3432
type=DatasetType.IMAGE,
35-
categories=HERE / "caltech101.categories",
33+
categories=BUILTIN_DIR / "caltech101.categories",
3634
homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech101",
3735
)
3836

@@ -135,12 +133,11 @@ def _make_datapipe(
135133
)
136134
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
137135

138-
def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None:
136+
def _generate_categories(self, root: pathlib.Path) -> List[str]:
139137
dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
140138
dp = TarArchiveReader(dp)
141139
dp: IterDataPipe = Filter(dp, self._is_not_background_image)
142-
dir_names = {pathlib.Path(path).parent.name for path, _ in dp}
143-
create_categories_file(HERE, self.name, sorted(dir_names))
140+
return sorted({pathlib.Path(path).parent.name for path, _ in dp})
144141

145142

146143
class Caltech256(Dataset):
@@ -149,7 +146,7 @@ def info(self) -> DatasetInfo:
149146
return DatasetInfo(
150147
"caltech256",
151148
type=DatasetType.IMAGE,
152-
categories=HERE / "caltech256.categories",
149+
categories=BUILTIN_DIR / "caltech256.categories",
153150
homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech256",
154151
)
155152

@@ -192,17 +189,8 @@ def _make_datapipe(
192189
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
193190
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
194191

195-
def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None:
192+
def _generate_categories(self, root: pathlib.Path) -> List[str]:
196193
dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
197194
dp = TarArchiveReader(dp)
198195
dir_names = {pathlib.Path(path).parent.name for path, _ in dp}
199-
categories = [name.split(".")[1] for name in sorted(dir_names)]
200-
create_categories_file(HERE, self.name, categories)
201-
202-
203-
if __name__ == "__main__":
204-
from torchvision.prototype.datasets import home
205-
206-
root = home()
207-
Caltech101().generate_categories_file(root)
208-
Caltech256().generate_categories_file(root)
196+
return [name.split(".")[1] for name in sorted(dir_names)]

torchvision/prototype/datasets/_builtin/cifar.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,14 @@
2424
DatasetType,
2525
)
2626
from torchvision.prototype.datasets.utils._internal import (
27-
create_categories_file,
2827
INFINITE_BUFFER_SIZE,
28+
BUILTIN_DIR,
2929
image_buffer_from_array,
3030
path_comparator,
3131
)
3232

3333
__all__ = ["Cifar10", "Cifar100"]
3434

35-
HERE = pathlib.Path(__file__).parent
36-
3735

3836
class CifarFileReader(IterDataPipe[Tuple[np.ndarray, int]]):
3937
def __init__(self, datapipe: IterDataPipe[Dict[str, Any]], *, labels_key: str) -> None:
@@ -95,13 +93,12 @@ def _make_datapipe(
9593
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
9694
return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder))
9795

98-
def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None:
96+
def _generate_categories(self, root: pathlib.Path) -> List[str]:
9997
dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
10098
dp = TarArchiveReader(dp)
10199
dp: IterDataPipe = Filter(dp, path_comparator("name", self._META_FILE_NAME))
102100
dp: IterDataPipe = Mapper(dp, self._unpickle)
103-
categories = next(iter(dp))[self._CATEGORIES_KEY]
104-
create_categories_file(HERE, self.name, categories)
101+
return next(iter(dp))[self._CATEGORIES_KEY]
105102

106103

107104
class Cifar10(_CifarBase):
@@ -118,7 +115,7 @@ def info(self) -> DatasetInfo:
118115
return DatasetInfo(
119116
"cifar10",
120117
type=DatasetType.RAW,
121-
categories=HERE / "cifar10.categories",
118+
categories=BUILTIN_DIR / "cifar10.categories",
122119
homepage="https://www.cs.toronto.edu/~kriz/cifar.html",
123120
)
124121

@@ -145,7 +142,7 @@ def info(self) -> DatasetInfo:
145142
return DatasetInfo(
146143
"cifar100",
147144
type=DatasetType.RAW,
148-
categories=HERE / "cifar100.categories",
145+
categories=BUILTIN_DIR / "cifar100.categories",
149146
homepage="https://www.cs.toronto.edu/~kriz/cifar.html",
150147
valid_options=dict(
151148
split=("train", "test"),
@@ -159,11 +156,3 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
159156
sha256="85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7",
160157
)
161158
]
162-
163-
164-
if __name__ == "__main__":
165-
from torchvision.prototype.datasets import home
166-
167-
root = home()
168-
Cifar10().generate_categories_file(root)
169-
Cifar100().generate_categories_file(root)

torchvision/prototype/datasets/_builtin/sbd.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import io
22
import pathlib
33
import re
4-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
4+
from typing import Any, Callable, Dict, List, Optional, Tuple
55

66
import numpy as np
77
import torch
@@ -24,24 +24,22 @@
2424
DatasetType,
2525
)
2626
from torchvision.prototype.datasets.utils._internal import (
27-
create_categories_file,
2827
INFINITE_BUFFER_SIZE,
28+
BUILTIN_DIR,
2929
read_mat,
3030
getitem,
3131
path_accessor,
3232
path_comparator,
3333
)
3434

35-
HERE = pathlib.Path(__file__).parent
36-
3735

3836
class SBD(Dataset):
3937
@property
4038
def info(self) -> DatasetInfo:
4139
return DatasetInfo(
4240
"sbd",
4341
type=DatasetType.IMAGE,
44-
categories=HERE / "caltech256.categories",
42+
categories=BUILTIN_DIR / "caltech256.categories",
4543
homepage="http://home.bharathh.info/pubs/codes/SBD/download.html",
4644
valid_options=dict(
4745
split=("train", "val", "train_noval"),
@@ -158,7 +156,7 @@ def _make_datapipe(
158156
)
159157
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(config=config, decoder=decoder))
160158

161-
def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None:
159+
def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]:
162160
dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
163161
dp = TarArchiveReader(dp)
164162
dp: IterDataPipe = Filter(dp, path_comparator("name", "category_names.m"))
@@ -172,15 +170,4 @@ def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None:
172170
# the first and last line contain no information
173171
for line in lines[1:-1]
174172
]
175-
categories = tuple(
176-
zip(*sorted(categories_and_labels, key=lambda category_and_label: int(category_and_label[1])))
177-
)[0]
178-
179-
create_categories_file(HERE, self.name, categories)
180-
181-
182-
if __name__ == "__main__":
183-
from torchvision.prototype.datasets import home
184-
185-
root = home()
186-
SBD().generate_categories_file(root)
173+
return tuple(zip(*sorted(categories_and_labels, key=lambda category_and_label: int(category_and_label[1]))))[0]
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import argparse
2+
import sys
3+
import unittest.mock
4+
import warnings
5+
6+
with warnings.catch_warnings():
7+
warnings.filterwarnings("ignore", message=r"The categories file .+? does not exist.", category=UserWarning)
8+
9+
from torchvision.prototype import datasets
10+
11+
from torchvision.prototype.datasets._api import find
12+
from torchvision.prototype.datasets.utils._internal import BUILTIN_DIR
13+
14+
15+
def main(*names, force=False):
16+
root = datasets.home()
17+
18+
for name in names:
19+
file = BUILTIN_DIR / f"{name}.categories"
20+
if file.exists() and not force:
21+
continue
22+
23+
dataset = find(name)
24+
try:
25+
with unittest.mock.patch(
26+
"torchvision.prototype.datasets.utils._dataset.DatasetInfo._read_categories_file", return_value=[]
27+
):
28+
categories = dataset._generate_categories(root)
29+
except NotImplementedError:
30+
continue
31+
32+
with open(file, "w") as fh:
33+
fh.write("\n".join(categories) + "\n")
34+
35+
36+
def parse_args(argv=None):
37+
parser = argparse.ArgumentParser(prog="torchvision.prototype.datasets.generate_category_files.py")
38+
39+
parser.add_argument(
40+
"names",
41+
nargs="?",
42+
type=str,
43+
help="Names of datasets to generate category files for. If omitted, all datasets will be used.",
44+
)
45+
parser.add_argument(
46+
"-f",
47+
"--force",
48+
action="store_true",
49+
help="Force regeneration of category files.",
50+
)
51+
52+
args = parser.parse_args(argv or sys.argv[1:])
53+
54+
if not args.names:
55+
args.names = datasets.list()
56+
57+
return args
58+
59+
60+
if __name__ == "__main__":
61+
args = parse_args()
62+
63+
try:
64+
main(*args.names, force=args.force)
65+
except Exception as error:
66+
msg = str(error)
67+
print(msg or f"Unspecified {type(error)} was raised during execution.", file=sys.stderr)
68+
sys.exit(1)

torchvision/prototype/datasets/utils/_dataset.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import pathlib
66
import textwrap
7+
import warnings
78
from collections import Mapping
89
from typing import (
910
Any,
@@ -117,8 +118,7 @@ def __init__(
117118
elif isinstance(categories, int):
118119
categories = [str(label) for label in range(categories)]
119120
elif isinstance(categories, (str, pathlib.Path)):
120-
with open(pathlib.Path(categories).expanduser().resolve(), "r") as fh:
121-
categories = [line.strip() for line in fh]
121+
categories = self._read_categories_file(pathlib.Path(categories).expanduser().resolve())
122122
self.categories = tuple(categories)
123123

124124
self.citation = citation
@@ -137,6 +137,17 @@ def __init__(
137137
)
138138
self._valid_options: Dict[str, Sequence] = valid_options
139139

140+
@staticmethod
141+
def _read_categories_file(path: pathlib.Path) -> List[str]:
142+
if not path.exists() or not path.is_file():
143+
warnings.warn(
144+
f"The categories file {path} does not exist. Continuing without loaded categories.", UserWarning
145+
)
146+
return []
147+
148+
with open(path, "r") as file:
149+
return [line.strip() for line in file]
150+
140151
@property
141152
def default_config(self) -> DatasetConfig:
142153
return DatasetConfig({name: valid_args[0] for name, valid_args in self._valid_options.items()})
@@ -219,3 +230,6 @@ def to_datapipe(
219230

220231
resource_dps = [resource.to_datapipe(root) for resource in self.resources(config)]
221232
return self._make_datapipe(resource_dps, config=config, decoder=decoder)
233+
234+
def _generate_categories(self, root: pathlib.Path) -> Sequence[str]:
235+
raise NotImplementedError

torchvision/prototype/datasets/utils/_internal.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515

1616
__all__ = [
1717
"INFINITE_BUFFER_SIZE",
18+
"BUILTIN_DIR",
1819
"sequence_to_str",
1920
"add_suggestion",
20-
"create_categories_file",
2121
"read_mat",
2222
"image_buffer_from_array",
2323
"SequenceIterator",
@@ -35,6 +35,8 @@
3535
# pseudo-infinite until a true infinite buffer is supported by all datapipes
3636
INFINITE_BUFFER_SIZE = 1_000_000_000
3737

38+
BUILTIN_DIR = pathlib.Path(__file__).parent.parent / "_builtin"
39+
3840

3941
def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
4042
if len(seq) == 1:
@@ -60,11 +62,6 @@ def add_suggestion(
6062
return f"{msg.strip()} {hint}"
6163

6264

63-
def create_categories_file(root: Union[str, pathlib.Path], name: str, categories: Sequence[str]) -> None:
64-
with open(pathlib.Path(root) / f"{name}.categories", "w") as fh:
65-
fh.write("\n".join(categories) + "\n")
66-
67-
6865
def read_mat(buffer: io.IOBase, **kwargs: Any) -> Any:
6966
try:
7067
import scipy.io as sio

0 commit comments

Comments
 (0)