|
| 1 | +import collections.abc |
1 | 2 | import contextlib
|
2 | 3 | import functools
|
3 | 4 | import importlib
|
@@ -227,16 +228,27 @@ def test_baz(self):
|
227 | 228 | "download_and_extract_archive",
|
228 | 229 | }
|
229 | 230 |
|
230 |
| - def inject_fake_data(self, root: str, config: Dict[str, Any]) -> Dict[str, Any]: |
231 |
| - """Inject fake data into the root of the dataset. |
| 231 | + def inject_fake_data( |
| 232 | + self, tmpdir: str, config: Dict[str, Any] |
| 233 | + ) -> Union[int, Dict[str, Any], Tuple[Sequence[Any], Union[int, Dict[str, Any]]]]: |
| 234 | + """Inject fake data for dataset into a temporary directory. |
232 | 235 |
|
233 | 236 | Args:
|
234 |
| - root (str): Root of the dataset. |
| 237 | + tmpdir (str): Path to a temporary directory. For most cases this acts as root directory for the dataset |
| 238 | + to be created and in turn also for the fake data injected here. |
235 | 239 | config (Dict[str, Any]): Configuration that will be used to create the dataset.
|
236 | 240 |
|
237 |
| - Returns: |
238 |
| - info (Dict[str, Any]): Additional information about the injected fake data. Must contain the field |
239 |
| - ``"num_examples"`` that corresponds to the length of the dataset to be created. |
| 241 | + Needs to return one of the following: |
| 242 | +
|
| 243 | + 1. (int): Number of examples in the dataset to be created, |
| 244 | + 2. (Dict[str, Any]): Additional information about the injected fake data. Must contain the field |
| 245 | + ``"num_examples"`` that corresponds to the number of examples in the dataset to be created, or |
| 246 | + 3. (Tuple[Sequence[Any], Union[int, Dict[str, Any]]]): The first element holds the positional arguments that are |
| 247 | + passed to the dataset constructor. The second element corresponds to cases 1. and 2. |
| 248 | +
|
| 249 | + If no ``args`` is returned (case 1. and 2.), the ``tmpdir`` is passed as first parameter to the dataset |
| 250 | + constructor. In most cases this corresponds to ``root``. If the dataset has more parameters without default |
| 251 | + values you need to explicitly pass them as explained in case 3. |
240 | 252 | """
|
241 | 253 | raise NotImplementedError("You need to provide fake data in order for the tests to run.")
|
242 | 254 |
|
@@ -274,17 +286,38 @@ def create_dataset(
|
274 | 286 | if disable_download_extract is None:
|
275 | 287 | disable_download_extract = inject_fake_data
|
276 | 288 |
|
277 |
| - with get_tmp_dir() as root: |
278 |
| - info = self.inject_fake_data(root, config) if inject_fake_data else None |
279 |
| - if info is None or "num_examples" not in info: |
| 289 | + with get_tmp_dir() as tmpdir: |
| 290 | + output = self.inject_fake_data(tmpdir, config) if inject_fake_data else None |
| 291 | + if output is None: |
| 292 | + raise UsageError( |
| 293 | + "The method 'inject_fake_data' needs to return at least an integer indicating the number of " |
| 294 | + "examples for the current configuration." |
| 295 | + ) |
| 296 | + |
| 297 | + if isinstance(output, collections.abc.Sequence) and len(output) == 2: |
| 298 | + args, info = output |
| 299 | + else: |
| 300 | + args = (tmpdir,) |
| 301 | + info = output |
| 302 | + |
| 303 | + if isinstance(info, int): |
| 304 | + info = dict(num_examples=info) |
| 305 | + elif isinstance(info, dict): |
| 306 | + if "num_examples" not in info: |
| 307 | + raise UsageError( |
| 308 | + "The information dictionary returned by the method 'inject_fake_data' must contain a " |
| 309 | + "'num_examples' field that holds the number of examples for the current configuration." |
| 310 | + ) |
| 311 | + else: |
280 | 312 | raise UsageError(
|
281 |
| - "The method 'inject_fake_data' needs to return a dictionary that contains at least a " |
282 |
| - "'num_examples' field." |
| 313 | + f"The additional information returned by the method 'inject_fake_data' must be either an integer " |
| 314 | + f"indicating the number of examples for the current configuration or a dictionary with the " |
| 315 | + f"same content. Got {type(info)} instead." |
283 | 316 | )
|
284 | 317 |
|
285 | 318 | cm = self._disable_download_extract if disable_download_extract else nullcontext
|
286 | 319 | with cm(special_kwargs), disable_console_output():
|
287 |
| - dataset = self.DATASET_CLASS(root, **config, **special_kwargs) |
| 320 | + dataset = self.DATASET_CLASS(*args, **config, **special_kwargs) |
288 | 321 |
|
289 | 322 | yield dataset, info
|
290 | 323 |
|
|
0 commit comments