6
6
import itertools
7
7
import os
8
8
import pathlib
9
+ import random
10
+ import string
9
11
import unittest
10
12
import unittest .mock
11
13
from typing import Any , Callable , Dict , Iterator , List , Optional , Sequence , Tuple , Union
32
34
"create_image_folder" ,
33
35
"create_video_file" ,
34
36
"create_video_folder" ,
37
+ "create_random_string" ,
35
38
]
36
39
37
40
@@ -93,14 +96,6 @@ def inner_wrapper(*args, **kwargs):
93
96
return outer_wrapper
94
97
95
98
96
- # As of Python 3.7 this is provided by contextlib
97
- # https://docs.python.org/3.7/library/contextlib.html#contextlib.nullcontext
98
- # TODO: If the minimum Python requirement is >= 3.7, replace this
99
- @contextlib .contextmanager
100
- def nullcontext (enter_result = None ):
101
- yield enter_result
102
-
103
-
104
99
def test_all_configs (test ):
105
100
"""Decorator to run test against all configurations.
106
101
@@ -116,7 +111,7 @@ def test_foo(self, config):
116
111
117
112
@functools .wraps (test )
118
113
def wrapper (self ):
119
- for config in self .CONFIGS :
114
+ for config in self .CONFIGS or ( self . _DEFAULT_CONFIG ,) :
120
115
with self .subTest (** config ):
121
116
test (self , config )
122
117
@@ -207,6 +202,8 @@ def test_baz(self):
207
202
CONFIGS = None
208
203
REQUIRED_PACKAGES = None
209
204
205
+ _DEFAULT_CONFIG = None
206
+
210
207
_TRANSFORM_KWARGS = {
211
208
"transform" ,
212
209
"target_transform" ,
@@ -268,7 +265,7 @@ def create_dataset(
268
265
self ,
269
266
config : Optional [Dict [str , Any ]] = None ,
270
267
inject_fake_data : bool = True ,
271
- disable_download_extract : Optional [bool ] = None ,
268
+ patch_checks : Optional [bool ] = None ,
272
269
** kwargs : Any ,
273
270
) -> Iterator [Tuple [torchvision .datasets .VisionDataset , Dict [str , Any ]]]:
274
271
r"""Create the dataset in a temporary directory.
@@ -278,8 +275,8 @@ def create_dataset(
278
275
default configuration is used.
279
276
inject_fake_data (bool): If ``True`` (default) inject the fake data with :meth:`.inject_fake_data` before
280
277
creating the dataset.
281
- disable_download_extract (Optional[bool]): If ``True`` disable download and extract logic while creating
282
- the dataset. If ``None`` (default) this takes the same value as ``inject_fake_data``.
278
+ patch_checks (Optional[bool]): If ``True`` disable integrity check logic while creating the dataset. If
279
+ omitted defaults to the same value as ``inject_fake_data``.
283
280
**kwargs (Any): Additional parameters passed to the dataset. These parameters take precedence in case they
284
281
overlap with ``config``.
285
282
@@ -288,43 +285,28 @@ def create_dataset(
288
285
info (Dict[str, Any]): Additional information about the injected fake data. See :meth:`.inject_fake_data`
289
286
for details.
290
287
"""
291
- if config is None :
292
- config = self .CONFIGS [0 ].copy ()
288
+ default_config = self ._DEFAULT_CONFIG .copy ()
289
+ if config is not None :
290
+ default_config .update (config )
291
+ config = default_config
292
+
293
+ if patch_checks is None :
294
+ patch_checks = inject_fake_data
293
295
294
296
special_kwargs , other_kwargs = self ._split_kwargs (kwargs )
297
+ if "download" in self ._HAS_SPECIAL_KWARG :
298
+ special_kwargs ["download" ] = False
295
299
config .update (other_kwargs )
296
300
297
- if disable_download_extract is None :
298
- disable_download_extract = inject_fake_data
301
+ patchers = self ._patch_download_extract ()
302
+ if patch_checks :
303
+ patchers .update (self ._patch_checks ())
299
304
300
305
with get_tmp_dir () as tmpdir :
301
306
args = self .dataset_args (tmpdir , config )
307
+ info = self ._inject_fake_data (tmpdir , config ) if inject_fake_data else None
302
308
303
- if inject_fake_data :
304
- info = self .inject_fake_data (tmpdir , config )
305
- if info is None :
306
- raise UsageError (
307
- "The method 'inject_fake_data' needs to return at least an integer indicating the number of "
308
- "examples for the current configuration."
309
- )
310
- elif isinstance (info , int ):
311
- info = dict (num_examples = info )
312
- elif not isinstance (info , dict ):
313
- raise UsageError (
314
- f"The additional information returned by the method 'inject_fake_data' must be either an "
315
- f"integer indicating the number of examples for the current configuration or a dictionary with "
316
- f"the same content. Got { type (info )} instead."
317
- )
318
- elif "num_examples" not in info :
319
- raise UsageError (
320
- "The information dictionary returned by the method 'inject_fake_data' must contain a "
321
- "'num_examples' field that holds the number of examples for the current configuration."
322
- )
323
- else :
324
- info = None
325
-
326
- cm = self ._disable_download_extract if disable_download_extract else nullcontext
327
- with cm (special_kwargs ), disable_console_output ():
309
+ with self ._maybe_apply_patches (patchers ), disable_console_output ():
328
310
dataset = self .DATASET_CLASS (* args , ** config , ** special_kwargs )
329
311
330
312
yield dataset , info
@@ -352,19 +334,17 @@ def _verify_required_public_class_attributes(cls):
352
334
@classmethod
353
335
def _populate_private_class_attributes (cls ):
354
336
argspec = inspect .getfullargspec (cls .DATASET_CLASS .__init__ )
337
+
338
+ cls ._DEFAULT_CONFIG = {
339
+ kwarg : default
340
+ for kwarg , default in zip (argspec .args [- len (argspec .defaults ):], argspec .defaults )
341
+ if kwarg not in cls ._SPECIAL_KWARGS
342
+ }
343
+
355
344
cls ._HAS_SPECIAL_KWARG = {name for name in cls ._SPECIAL_KWARGS if name in argspec .args }
356
345
357
346
@classmethod
358
347
def _process_optional_public_class_attributes (cls ):
359
- argspec = inspect .getfullargspec (cls .DATASET_CLASS .__init__ )
360
- if cls .CONFIGS is None :
361
- config = {
362
- kwarg : default
363
- for kwarg , default in zip (argspec .args [- len (argspec .defaults ):], argspec .defaults )
364
- if kwarg not in cls ._SPECIAL_KWARGS
365
- }
366
- cls .CONFIGS = (config ,)
367
-
368
348
if cls .REQUIRED_PACKAGES is not None :
369
349
try :
370
350
for pkg in cls .REQUIRED_PACKAGES :
@@ -380,28 +360,44 @@ def _split_kwargs(self, kwargs):
380
360
other_kwargs = {key : special_kwargs .pop (key ) for key in set (special_kwargs .keys ()) - self ._SPECIAL_KWARGS }
381
361
return special_kwargs , other_kwargs
382
362
383
- @contextlib .contextmanager
384
- def _disable_download_extract (self , special_kwargs ):
385
- inject_download_kwarg = "download" in self ._HAS_SPECIAL_KWARG and "download" not in special_kwargs
386
- if inject_download_kwarg :
387
- special_kwargs ["download" ] = False
363
+ def _inject_fake_data (self , tmpdir , config ):
364
+ info = self .inject_fake_data (tmpdir , config )
365
+ if info is None :
366
+ raise UsageError (
367
+ "The method 'inject_fake_data' needs to return at least an integer indicating the number of "
368
+ "examples for the current configuration."
369
+ )
370
+ elif isinstance (info , int ):
371
+ info = dict (num_examples = info )
372
+ elif not isinstance (info , dict ):
373
+ raise UsageError (
374
+ f"The additional information returned by the method 'inject_fake_data' must be either an "
375
+ f"integer indicating the number of examples for the current configuration or a dictionary with "
376
+ f"the same content. Got { type (info )} instead."
377
+ )
378
+ elif "num_examples" not in info :
379
+ raise UsageError (
380
+ "The information dictionary returned by the method 'inject_fake_data' must contain a "
381
+ "'num_examples' field that holds the number of examples for the current configuration."
382
+ )
383
+ return info
384
+
385
+ def _patch_download_extract (self ):
386
+ module = inspect .getmodule (self .DATASET_CLASS ).__name__
387
+ return {unittest .mock .patch (f"{ module } .{ function } " ) for function in self ._DOWNLOAD_EXTRACT_FUNCTIONS }
388
388
389
+ def _patch_checks (self ):
389
390
module = inspect .getmodule (self .DATASET_CLASS ).__name__
391
+ return {unittest .mock .patch (f"{ module } .{ function } " , return_value = True ) for function in self ._CHECK_FUNCTIONS }
392
+
393
+ @contextlib .contextmanager
394
+ def _maybe_apply_patches (self , patchers ):
390
395
with contextlib .ExitStack () as stack :
391
396
mocks = {}
392
- for function , kwargs in itertools .chain (
393
- zip (self ._CHECK_FUNCTIONS , [dict (return_value = True )] * len (self ._CHECK_FUNCTIONS )),
394
- zip (self ._DOWNLOAD_EXTRACT_FUNCTIONS , [dict ()] * len (self ._DOWNLOAD_EXTRACT_FUNCTIONS )),
395
- ):
397
+ for patcher in patchers :
396
398
with contextlib .suppress (AttributeError ):
397
- patcher = unittest .mock .patch (f"{ module } .{ function } " , ** kwargs )
398
- mocks [function ] = stack .enter_context (patcher )
399
-
400
- try :
401
- yield mocks
402
- finally :
403
- if inject_download_kwarg :
404
- del special_kwargs ["download" ]
399
+ mocks [patcher .target ] = stack .enter_context (patcher )
400
+ yield mocks
405
401
406
402
def test_not_found_or_corrupted (self ):
407
403
with self .assertRaises ((FileNotFoundError , RuntimeError )):
@@ -469,13 +465,13 @@ def create_dataset(
469
465
self ,
470
466
config : Optional [Dict [str , Any ]] = None ,
471
467
inject_fake_data : bool = True ,
472
- disable_download_extract : Optional [bool ] = None ,
468
+ patch_checks : Optional [bool ] = None ,
473
469
** kwargs : Any ,
474
470
) -> Iterator [Tuple [torchvision .datasets .VisionDataset , Dict [str , Any ]]]:
475
471
with super ().create_dataset (
476
472
config = config ,
477
473
inject_fake_data = inject_fake_data ,
478
- disable_download_extract = disable_download_extract ,
474
+ patch_checks = patch_checks ,
479
475
** kwargs ,
480
476
) as (dataset , info ):
481
477
# PIL.Image.open() only loads the image meta data upfront and keeps the file open until the first access
@@ -572,7 +568,7 @@ def create_image_file(
572
568
573
569
image = create_image_or_video_tensor (size )
574
570
file = pathlib .Path (root ) / name
575
- PIL .Image .fromarray (image .permute (2 , 1 , 0 ).numpy ()).save (file )
571
+ PIL .Image .fromarray (image .permute (2 , 1 , 0 ).numpy ()).save (file , ** kwargs )
576
572
return file
577
573
578
574
@@ -708,6 +704,21 @@ def size(idx):
708
704
os .makedirs (root )
709
705
710
706
return [
711
- create_video_file (root , file_name_fn (idx ), size = size (idx ) if callable (size ) else size )
707
+ create_video_file (root , file_name_fn (idx ), size = size (idx ) if callable (size ) else size , ** kwargs )
712
708
for idx in range (num_examples )
713
709
]
710
+
711
+
712
def create_random_string(length: int, *digits: str) -> str:
    """Create a random string.

    Args:
        length (int): Number of characters in the generated string.
        *digits (str): Characters to sample from. If omitted, or if no characters are left after joining all
            arguments, defaults to :attr:`string.ascii_lowercase`.

    Returns:
        str: ``length`` characters, each drawn independently with :func:`random.choice`.
    """
    # Flatten all arguments into one pool. An empty pool (no arguments or only empty strings) falls back to
    # the default alphabet so that random.choice never sees an empty sequence.
    pool = "".join(itertools.chain.from_iterable(digits)) or string.ascii_lowercase
    return "".join(random.choice(pool) for _ in range(length))
0 commit comments