Skip to content

Commit aba3ee0

Browse files
committed
enable additional arg forwarding
1 parent 170f700 commit aba3ee0

File tree

2 files changed

+55
-22
lines changed

2 files changed

+55
-22
lines changed

test/datasets_utils.py

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import collections.abc
12
import contextlib
23
import functools
34
import importlib
@@ -227,16 +228,27 @@ def test_baz(self):
227228
"download_and_extract_archive",
228229
}
229230

230-
def inject_fake_data(self, root: str, config: Dict[str, Any]) -> Dict[str, Any]:
231-
"""Inject fake data into the root of the dataset.
231+
def inject_fake_data(
232+
self, tmpdir: str, config: Dict[str, Any]
233+
) -> Union[int, Dict[str, Any], Tuple[Sequence[Any], Union[int, Dict[str, Any]]]]:
234+
"""Inject fake data for dataset into a temporary directory.
232235
233236
Args:
234-
root (str): Root of the dataset.
237+
tmpdir (str): Path to a temporary directory. In most cases this acts as the root directory for the dataset
238+
to be created and in turn also for the fake data injected here.
235239
config (Dict[str, Any]): Configuration that will be used to create the dataset.
236240
237-
Returns:
238-
info (Dict[str, Any]): Additional information about the injected fake data. Must contain the field
239-
``"num_examples"`` that corresponds to the length of the dataset to be created.
241+
Needs to return one of the following:
242+
243+
1. (int): Number of examples in the dataset to be created,
244+
2. (Dict[str, Any]): Additional information about the injected fake data. Must contain the field
245+
``"num_examples"`` that corresponds to the number of examples in the dataset to be created, or
246+
3. (Tuple[Sequence[Any], Union[int, Dict[str, Any]]]): Additional required parameters that are passed to
247+
the dataset constructor, given as the first element. The second element is the value described in case 1. or 2.
248+
249+
If no ``args`` are returned (cases 1. and 2.), ``tmpdir`` is passed as the first parameter to the dataset
250+
constructor. In most cases this corresponds to ``root``. If the dataset has more parameters without default
251+
values you need to explicitly pass them as explained in case 3.
240252
"""
241253
raise NotImplementedError("You need to provide fake data in order for the tests to run.")
242254

@@ -274,17 +286,38 @@ def create_dataset(
274286
if disable_download_extract is None:
275287
disable_download_extract = inject_fake_data
276288

277-
with get_tmp_dir() as root:
278-
info = self.inject_fake_data(root, config) if inject_fake_data else None
279-
if info is None or "num_examples" not in info:
289+
with get_tmp_dir() as tmpdir:
290+
output = self.inject_fake_data(tmpdir, config) if inject_fake_data else None
291+
if output is None:
292+
raise UsageError(
293+
"The method 'inject_fake_data' needs to return at least an integer indicating the number of "
294+
"examples for the current configuration."
295+
)
296+
297+
if isinstance(output, collections.abc.Sequence) and len(output) == 2:
298+
args, info = output
299+
else:
300+
args = (tmpdir,)
301+
info = output
302+
303+
if isinstance(info, int):
304+
info = dict(num_examples=info)
305+
elif isinstance(info, dict):
306+
if "num_examples" not in info:
307+
raise UsageError(
308+
"The information dictionary returned by the method 'inject_fake_data' must contain a "
309+
"'num_examples' field that holds the number of examples for the current configuration."
310+
)
311+
else:
280312
raise UsageError(
281-
"The method 'inject_fake_data' needs to return a dictionary that contains at least a "
282-
"'num_examples' field."
313+
f"The additional information returned by the method 'inject_fake_data' must be either an integer "
314+
f"indicating the number of examples for the current configuration or a dictionary with the "
315+
f"same content. Got {type(info)} instead."
283316
)
284317

285318
cm = self._disable_download_extract if disable_download_extract else nullcontext
286319
with cm(special_kwargs), disable_console_output():
287-
dataset = self.DATASET_CLASS(root, **config, **special_kwargs)
320+
dataset = self.DATASET_CLASS(*args, **config, **special_kwargs)
288321

289322
yield dataset, info
290323

test/test_datasets.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -473,21 +473,21 @@ def test_repr_smoke(self):
473473
class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
474474
DATASET_CLASS = datasets.Caltech256
475475

476-
def inject_fake_data(self, root, config):
477-
root = pathlib.Path(root) / "caltech256" / "256_ObjectCategories"
476+
def inject_fake_data(self, tmpdir, config):
477+
tmpdir = pathlib.Path(tmpdir) / "caltech256" / "256_ObjectCategories"
478478

479479
categories = ((1, "ak47"), (127, "laptop-101"), (257, "clutter"))
480480
num_images_per_category = 2
481481

482482
for idx, category in categories:
483483
datasets_utils.create_image_folder(
484-
root,
484+
tmpdir,
485485
name=f"{idx:03d}.{category}",
486-
file_name_fn=lambda image_idx: f"{idx:03d}_{image_idx:04d}.jpg",
486+
file_name_fn=lambda image_idx: f"{idx:03d}_{image_idx + 1:04d}.jpg",
487487
num_examples=num_images_per_category,
488488
)
489489

490-
return dict(num_examples=num_images_per_category * len(categories))
490+
return num_images_per_category * len(categories)
491491

492492

493493
class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase):
@@ -504,15 +504,15 @@ class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase):
504504
categories_key="label_names",
505505
)
506506

507-
def inject_fake_data(self, root, config):
508-
root = pathlib.Path(root) / self._VERSION_CONFIG["base_folder"]
509-
os.makedirs(root)
507+
def inject_fake_data(self, tmpdir, config):
508+
tmpdir = pathlib.Path(tmpdir) / self._VERSION_CONFIG["base_folder"]
509+
os.makedirs(tmpdir)
510510

511511
num_images_per_file = 1
512512
for name in itertools.chain(self._VERSION_CONFIG["train_files"], self._VERSION_CONFIG["test_files"]):
513-
self._create_batch_file(root, name, num_images_per_file)
513+
self._create_batch_file(tmpdir, name, num_images_per_file)
514514

515-
categories = self._create_meta_file(root)
515+
categories = self._create_meta_file(tmpdir)
516516

517517
return dict(
518518
num_examples=num_images_per_file

0 commit comments

Comments
 (0)