Finally, `from torchvision.prototype import datasets` is implied below.

Before we start with the actual implementation, you should create a module in `torchvision/prototype/datasets/_builtin`
that hints at the dataset you are going to add. For example `caltech.py` for `caltech101` and `caltech256`. In that
module create a class that inherits from `datasets.utils.Dataset` and overwrites four methods that will be discussed in
detail below:

```python
import pathlib
from typing import Any, BinaryIO, Dict, List, Tuple, Union

from torchdata.datapipes.iter import IterDataPipe
from torchvision.prototype.datasets.utils import Dataset, OnlineResource

from .._api import register_dataset, register_info

NAME = "my-dataset"

@register_info(NAME)
def _info() -> Dict[str, Any]:
    return dict(
        ...
    )

@register_dataset(NAME)
class MyDataset(Dataset):
    def __init__(self, root: Union[str, pathlib.Path], *, ..., skip_integrity_check: bool = False) -> None:
        ...
        super().__init__(root, skip_integrity_check=skip_integrity_check)

    def _resources(self) -> List[OnlineResource]:
        ...

    def _datapipe(self, resource_dps: List[IterDataPipe[Tuple[str, BinaryIO]]]) -> IterDataPipe[Dict[str, Any]]:
        ...

    def __len__(self) -> int:
        ...
```

In addition to the dataset, you also need to implement an `_info()` function that takes no arguments and returns a
dictionary of static information. The most common use case is to provide human-readable categories.
[See below](#how-do-i-handle-a-dataset-that-defines-many-categories) how to handle cases with many categories.

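For example, a minimal `_info()` for a dataset with only a couple of hard-coded categories could look like this sketch
(the category names are made up):

```py
@register_info(NAME)
def _info() -> Dict[str, Any]:
    return dict(categories=["cat", "dog"])
```
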
Finally, both the dataset class and the info function need to be registered on the API with the respective decorators.
With that they are loadable through `datasets.load("my-dataset")` and `datasets.info("my-dataset")`, respectively.

### `__init__(self, root, *, ..., skip_integrity_check = False)`

The constructor of the dataset is called when the dataset is instantiated. In addition to the parameters of the base
class, it can take arbitrary keyword-only parameters with defaults. The checking of these parameters as well as setting
them as instance attributes has to happen before the call of `super().__init__(...)`, because that call will invoke the
other methods, which possibly depend on the parameters. All instance attributes must be private, i.e. prefixed with an
underscore.

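A sketch of such a constructor, assuming a hypothetical `split` parameter:

```py
def __init__(
    self,
    root: Union[str, pathlib.Path],
    *,
    split: str = "train",
    skip_integrity_check: bool = False,
) -> None:
    # check the parameter and set the instance attribute *before* calling the base class constructor
    if split not in ("train", "test"):
        raise ValueError(f"split should be 'train' or 'test', but got {split}")
    self._split = split

    super().__init__(root, skip_integrity_check=skip_integrity_check)
```
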
If the implementation of the dataset depends on third-party packages, pass them as a collection of strings to the base
class constructor, e.g. `super().__init__(..., dependencies=("scipy",))`. Their availability will be automatically
checked if a user tries to load the dataset. Within the implementation of the dataset, import these packages lazily to
avoid missing dependencies at import time.

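For example, a dataset that needs `scipy` to parse its annotation files might look like this sketch (method bodies
shortened):

```py
class MyDataset(Dataset):
    def __init__(self, root, *, skip_integrity_check=False):
        super().__init__(root, skip_integrity_check=skip_integrity_check, dependencies=("scipy",))

    def _datapipe(self, resource_dps):
        import scipy.io  # imported lazily, so importing torchvision does not require scipy

        ...
```
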
### `_resources(self)`

Returns a `List[datasets.utils.OnlineResource]` of all the files that need to be present locally before the dataset can
be built. The download will happen automatically.

Currently, the following `OnlineResource`'s are supported:

- `HttpResource`: Used for files that are directly exposed through HTTP(S) and only requires the URL.
- `GDriveResource`: Used for files that are hosted on Google Drive and requires the GDrive ID as well as the
  `file_name`.
- `ManualDownloadResource`: Used for files that are not publicly accessible and requires instructions on how to download
  them manually. If the file does not exist, an error is raised with the supplied instructions.
- `KaggleDownloadResource`: Used for files that are available on Kaggle. This inherits from `ManualDownloadResource`.

Although optional in general, all resources used in the built-in datasets should include a
[SHA256](https://en.wikipedia.org/wiki/SHA-2) checksum for security. It will be automatically checked after the
download. You can compute the checksum with system utilities, e.g. `sha256sum`, or with this snippet:

```py
import hashlib


def sha256sum(path, chunk_size=1024 * 1024):
    checksum = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            checksum.update(chunk)
    print(checksum.hexdigest())
```

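Putting this together, a sketch of `_resources()` for a dataset that is distributed as a single archive over HTTP
(URL and checksum are placeholders):

```py
def _resources(self) -> List[OnlineResource]:
    archive = HttpResource(
        "https://example.com/my-dataset.tar.gz",
        sha256="0123456789abcdef...",  # placeholder, compute it as shown above
    )
    return [archive]
```
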
### `_datapipe(self, resource_dps)`

This method is the heart of the dataset, where we transform the raw data into a usable form. A major difference compared
to the current stable datasets is that everything is performed through `IterDataPipe`'s. From the perspective of someone
that is working with them rather than on them, `IterDataPipe`'s behave just like generators, i.e. you can't do much with
them apart from iterating.

Commonly used building blocks are `Mapper`, `Filter`, `Demultiplexer`, and `IterKeyZipper`. All of them can be imported
`from torchdata.datapipes.iter`. In addition, use `functools.partial()` if a callable passed to a datapipe
needs extra arguments. If the provided `IterDataPipe`'s are not sufficient for the use case, it is also not complicated
to add one. See the MNIST or CelebA datasets for example.

`_datapipe()` receives `resource_dps`, which is a list of datapipes that has a 1-to-1 correspondence with the return
value of `_resources()`. In case of archives with regular suffixes (`.tar`, `.zip`, ...), the datapipe will contain
tuples comprised of the path and the handle for every file in the archive. Otherwise, the datapipe will only contain one
such tuple for the file specified by the resource.

Since the datapipes are iterable in nature, some datapipes feature an in-memory buffer, e.g. `IterKeyZipper` and
`Grouper`. There are two issues with that:

1. If not used carefully, this can easily overflow the host memory, since most datasets will not fit in completely.
2. This can lead to unnecessarily long warm-up times when data is buffered that is only needed at runtime.

Thus, all buffered datapipes should be used as early as possible, e.g. zipping two datapipes of file handles rather than
trying to zip already loaded images.

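For example, joining images with their annotations should happen on the `(path, file handle)` tuples rather than on
decoded samples. A sketch with `IterKeyZipper`, where the datapipe names and the key function are made up:

```py
import pathlib

from torchdata.datapipes.iter import IterKeyZipper


def filename_stem(data):
    path, _ = data  # items are (path, file handle) tuples
    return pathlib.Path(path).stem


# zip raw file handles as early as possible instead of already loaded images
dp = IterKeyZipper(images_dp, anns_dp, key_fn=filename_stem, ref_key_fn=filename_stem)
```
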
There are two special datapipes that are not used through their class, but through the functions `hint_shuffling` and
`hint_sharding`. As the name implies, they only hint at a location in the datapipe graph where shuffling and sharding
should take place, but are no-ops by default. They can be imported from `torchvision.prototype.datasets.utils._internal`
and are required in each dataset. `hint_shuffling` has to be placed before `hint_sharding`.

Finally, each item in the final datapipe should be a dictionary with `str` keys. There is no standardization of the
names (yet!).

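A minimal sketch of `_datapipe()` that combines these pieces — the filtering predicate and the sample dictionary keys
are made up for illustration:

```py
from torchdata.datapipes.iter import Filter, Mapper
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling


class MyDataset(Dataset):
    ...

    def _is_image_file(self, data):
        path, _ = data
        return path.endswith(".jpg")

    def _prepare_sample(self, data):
        path, handle = data
        return dict(path=path, image=handle)

    def _datapipe(self, resource_dps):
        dp = resource_dps[0]
        dp = Filter(dp, self._is_image_file)
        dp = hint_shuffling(dp)
        dp = hint_sharding(dp)
        return Mapper(dp, self._prepare_sample)
```
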
### `__len__`

This returns an integer denoting the number of samples that can be drawn from the dataset. Please use
[underscores](https://peps.python.org/pep-0515/) after every three digits starting from the right to enhance
readability. For example, `1_281_167` vs. `1281167`.

If there are only two different numbers, a simple `if` / `else` is fine:

```py
def __len__(self):
    return 12_345 if self._split == "train" else 6_789
```

If there are more options, using a dictionary usually is the most readable option:

```py
def __len__(self):
    return {
        "train": 3,
        "val": 2,
        "test": 1,
    }[self._split]
```

If the number of samples depends on more than one parameter, you can use tuples as dictionary keys:

```py
def __len__(self):
    return {
        ("train", "bar"): 4,
        ("train", "baz"): 3,
        ("test", "bar"): 2,
        ("test", "baz"): 1,
    }[(self._split, self._foo)]
```

The length of the datapipe is only an annotation for subsequent processing of the datapipe and is not needed during the
development process. Since it is an `@abstractmethod` you still have to implement it from the start. The canonical way
is to define a dummy method like

```py
def __len__(self):
    return 1
```

and only fill it with the correct data once the implementation is otherwise finished.
[See below](#how-do-i-compute-the-number-of-samples) for a possible way to compute the number of samples.

## Tests

To test the dataset implementation, you usually don't need to add any tests, but you do need to provide a mock-up of the
data. This mock-up should resemble the original data as closely as necessary, while containing only a few examples.

To do this, add a new function in [`test/builtin_dataset_mocks.py`](../../../../test/builtin_dataset_mocks.py) with the
same name as you have used in `@register_info` and `@register_dataset`. This function is called the "mock data function".
Decorate it with `@register_mock(configs=[dict(...), ...])`. Each dictionary denotes one configuration that the dataset
will be loaded with, e.g. `datasets.load("my-dataset", **config)`. For the most common case of a product of all options,
you can use the `combinations_grid()` helper function, e.g.
`configs=combinations_grid(split=("train", "test"), foo=("bar", "baz"))`.

In case the name of the dataset includes hyphens `-`, replace them with underscores `_` in the function name and pass
the `name` parameter to `@register_mock`:

```py
# this is defined in torchvision/prototype/datasets/_builtin
@register_dataset("my-dataset")
class MyDataset(Dataset):
    ...


# this is defined in test/builtin_dataset_mocks.py
@register_mock(name="my-dataset", configs=...)
def my_dataset(root, config):
    ...
```

The mock data function receives two arguments:

- `root`: A [`pathlib.Path`](https://docs.python.org/3/library/pathlib.html#pathlib.Path) of a folder, in which the data
  needs to be placed.
- `config`: The configuration to generate the data for. This is one of the dictionaries defined in
  `@register_mock(configs=...)`.

The function should generate all files that are needed for the current `config`. Each file should be complete, e.g. if
the dataset only has a single archive that contains multiple splits, you need to generate the full archive regardless of
the current `config`. Although this seems odd at first, this is important. Consider the following original data setup:

```
root
├── train
│   └── ...
└── test
    └── ...
```

For map-style datasets (like the ones currently in `torchvision.datasets`), one explicitly selects the files they want
to load. For example, something like `(root / split).iterdir()` works fine even if only the specific split folder is
present. With iterable-style datasets though, we get something like `root.iterdir()` from `resource_dps` in
`_datapipe()` and need to manually `Filter` it to only keep the files we want. If we only generated the data for the
current `config`, the test would also pass if the dataset is missing the filtering, but it would fail on the real data.

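As a concrete sketch, a mock data function for the setup above could always generate both splits and pack them into the
single archive the real dataset ships as. The file names, the `make_tar` call, and the returned sample count are
illustrative assumptions; check the existing mock data functions for the exact conventions:

```py
@register_mock(configs=combinations_grid(split=("train", "test")))
def my_dataset(root, config):
    # generate *all* splits, regardless of config["split"]
    num_samples_per_split = 3
    for split in ("train", "test"):
        split_folder = root / "my-dataset" / split
        split_folder.mkdir(parents=True)
        for idx in range(num_samples_per_split):
            (split_folder / f"{split}_{idx}.txt").write_text("dummy data", encoding="utf-8")

    # pack the files into an archive with the helpers mentioned below, e.g. make_tar or make_zip
    # (the exact helper signature may differ)
    make_tar(root, "my-dataset.tar", root / "my-dataset")

    return num_samples_per_split
```
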
For datasets that are ported from the old API, we already have some mock data in
[`test/test_datasets.py`](../../../../test/test_datasets.py). You can find the corresponding test case there and have a
look at the `inject_fake_data` function. There are a few differences, though:

- `tmp_dir` corresponds to `root`, but is a `str` rather than a
  [`pathlib.Path`](https://docs.python.org/3/library/pathlib.html#pathlib.Path). Thus, you often see something like
  `folder = pathlib.Path(tmp_dir)`. This is not needed.
- The data generated by `inject_fake_data` was supposed to be in an extracted state. This is no longer the case for the
  new mock-ups. Thus, you need to use helper functions like `make_zip` or `make_tar` to actually generate the files
  specified in the dataset.

Finally, you can run the tests with `pytest test/test_prototype_builtin_datasets.py`.

## FAQ

### How do I start?

Get the skeleton of your dataset class ready with all 4 methods. For `_datapipe()`, you can just do
`return resource_dps[0]` to get started. Then import the dataset class in
`torchvision/prototype/datasets/_builtin/__init__.py`: this will automatically register the dataset, and it will be
instantiable via `datasets.load("mydataset")`. In a separate script, try something like

```py
from torchvision.prototype import datasets

dataset = datasets.load("mydataset")
for sample in dataset:
    print(sample)  # this is the content of an item in the datapipe returned by _datapipe()
    break
# Or you can also inspect the sample in a debugger
```

From there, you can incrementally add more datapipes and return the appropriate dictionary format.

### How do I handle a dataset that defines many categories?

As a rule of thumb, `categories` in the info dictionary should only be set manually for ten categories or fewer. If more
categories are needed, you can add a `$NAME.categories` file to the `_builtin` folder in which each line specifies a
category. To load such a file, use the `read_categories_file` function from
`torchvision.prototype.datasets.utils._internal` and pass it `$NAME`.

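A sketch of an `_info()` function that loads its categories from such a file:

```py
from torchvision.prototype.datasets.utils._internal import read_categories_file

NAME = "my-dataset"


@register_info(NAME)
def _info() -> Dict[str, Any]:
    return dict(categories=read_categories_file(NAME))
```
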
In case the categories can be generated from the dataset files, e.g. the dataset follows an image folder approach where
each folder denotes the name of the category, the dataset can overwrite the `_generate_categories` method. The method
should return a sequence of strings representing the category names. In the method body, you'll have to manually load
the resources, e.g.

```py
resources = self._resources()
dp = resources[0].load(self._root)
```

Note that it is not necessary here to keep a datapipe until the final step. Stick with datapipes as long as it makes
sense and afterwards materialize the data with `next(iter(dp))` or `list(dp)` and proceed with that.

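For the image folder scenario above, a minimal sketch of `_generate_categories` could look like this (it assumes that
the parent folder name of each file is the category):

```py
def _generate_categories(self):
    resources = self._resources()
    dp = resources[0].load(self._root)
    # each item is a (path, file handle) tuple, so the parent folder name is the category
    categories = {pathlib.Path(path).parent.name for path, _ in dp}
    return sorted(categories)
```
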
To generate the `$NAME.categories` file, run `python -m torchvision.prototype.datasets.generate_category_files $NAME`.

### What if a resource file forms an I/O bottleneck?

If the performance hit becomes significant, the archives can still be preprocessed. `OnlineResource` accepts a
`preprocess` parameter that can be a `Callable[[pathlib.Path], pathlib.Path]` where the input points to the file to be
preprocessed and the return value should be the result of the preprocessing to load. For convenience, `preprocess` also
accepts `"decompress"` and `"extract"` to handle these common scenarios.

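For example, a resource that should be extracted once right after the download could be declared like this
(URL and checksum are placeholders):

```py
def _resources(self) -> List[OnlineResource]:
    return [
        HttpResource(
            "https://example.com/my-dataset.zip",
            sha256="0123456789abcdef...",  # placeholder
            preprocess="extract",
        )
    ]
```
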
### How do I compute the number of samples?

Unless the authors of the dataset published the exact numbers (even in this case we should check), there is no other way
than to iterate over the dataset and count the number of samples:

```py
import itertools
from torchvision.prototype import datasets


def combinations_grid(**kwargs):
    return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]


# If you have implemented the mock data function for the dataset tests, you can simply copy-paste from there
configs = combinations_grid(split=("train", "test"), foo=("bar", "baz"))

for config in configs:
    dataset = datasets.load("my-dataset", **config)

    num_samples = 0
    for _ in dataset:
        num_samples += 1

    print(", ".join(f"{key}={value}" for key, value in config.items()), num_samples)
```

To speed this up, it is useful to temporarily comment out all unnecessary I/O, such as loading of images or annotation
files.