13
13
import unittest .mock
14
14
import hashlib
15
15
from distutils import dir_util
16
+ import re
17
+
18
+
19
def mock_class_attribute(stack, target, new):
    """Replace the attribute at ``target`` with a property mock that yields ``new``.

    The patcher is entered on ``stack``, so the replacement stays active until
    ``stack`` is closed.

    Args:
        stack: Object exposing ``enter_context`` (e.g. ``contextlib.ExitStack``).
        target: Dotted path handed to ``unittest.mock.patch``.
        new: Value the installed ``PropertyMock`` returns on attribute access.

    Returns:
        The patcher created by ``unittest.mock.patch`` (not the mock it installs).
    """
    property_mock_kwargs = {
        "new_callable": unittest.mock.PropertyMock,
        "return_value": new,
    }
    patcher = unittest.mock.patch(target, **property_mock_kwargs)
    stack.enter_context(patcher)
    return patcher
23
+
24
+
25
def compute_md5(file):
    """Return the hexadecimal MD5 digest of the content of ``file``.

    Args:
        file: Path of the file to hash. It is opened in binary mode and read in one go.

    Returns:
        The digest as a string of hexadecimal digits.
    """
    digest = hashlib.md5()
    with open(file, "rb") as stream:
        digest.update(stream.read())
    return digest.hexdigest()
28
+
29
+
30
def make_tar(root, name, *files, compression=None):
    """Pack ``files`` from ``root`` into a tar archive that is also placed in ``root``.

    Args:
        root: Directory that contains ``files`` and receives the archive.
        name: Archive name; its last extension is replaced by ``.tar`` or ``.tar.<compression>``.
        *files: Paths relative to ``root``. Each one is stored in the archive under that relative path.
        compression: Optional suffix appended to both the extension and the ``tarfile`` write mode
            (for example ``"gz"`` results in ``.tar.gz`` and ``"w:gz"``).

    Returns:
        Tuple of the archive file name (without ``root``) and the MD5 hex digest of the archive.
    """
    uncompressed = compression is None
    suffix = ".tar" if uncompressed else f".tar.{compression}"
    write_mode = "w" if uncompressed else f"w:{compression}"

    stem, _ = os.path.splitext(name)
    archive_name = stem + suffix
    archive_path = os.path.join(root, archive_name)

    with tarfile.open(archive_path, write_mode) as tar:
        for member in files:
            # arcname keeps the member's path relative to root inside the archive
            tar.add(os.path.join(root, member), arcname=member)

    return archive_name, compute_md5(archive_path)
45
+
46
+
47
def clean_dir(root, *keep):
    """Delete every direct child of ``root`` whose name matches none of the ``keep`` patterns.

    Args:
        root: Directory whose direct children are inspected.
        *keep: Regular expressions. A child is preserved if ``re.search`` finds any of them in its
            name. If no pattern is given, every child of ``root`` is removed.
    """
    # Function-scope import so the block does not depend on distutils' remove_tree, which is
    # unavailable since distutils was removed from the standard library in Python 3.12.
    import shutil

    # Without patterns the joined expression would be "()", which matches every name and would
    # silently keep everything. Treat "no patterns" as "keep nothing" instead.
    pattern = re.compile("|".join(f"({regex})" for regex in keep)) if keep else None

    for entry in os.listdir(root):
        if pattern is not None and pattern.search(entry):
            continue

        path = os.path.join(root, entry)
        if os.path.isdir(path) and not os.path.islink(path):
            shutil.rmtree(path)
        else:
            # Regular files and symlinks (including links to directories) are unlinked, not recursed into
            os.remove(path)
16
58
17
59
18
60
@contextlib .contextmanager
@@ -385,7 +427,7 @@ def ucf101_root():
385
427
386
428
387
429
@contextlib .contextmanager
388
- def places365_root (split = "train-standard" , small = False , extract_images = True ):
430
+ def places365_root (split = "train-standard" , small = False ):
389
431
VARIANTS = {
390
432
"train-standard" : "standard" ,
391
433
"train-challenge" : "challenge" ,
@@ -425,15 +467,6 @@ def places365_root(split="train-standard", small=False, extract_images=True):
425
467
def mock_target (attr , partial = "torchvision.datasets.places365.Places365" ):
426
468
return f"{ partial } .{ attr } "
427
469
428
- def mock_class_attribute (stack , attr , new ):
429
- mock = unittest .mock .patch (mock_target (attr ), new_callable = unittest .mock .PropertyMock , return_value = new )
430
- stack .enter_context (mock )
431
- return mock
432
-
433
- def compute_md5 (file ):
434
- with open (file , "rb" ) as fh :
435
- return hashlib .md5 (fh .read ()).hexdigest ()
436
-
437
470
def make_txt (root , name , seq ):
438
471
file = os .path .join (root , name )
439
472
with open (file , "w" ) as fh :
@@ -451,37 +484,20 @@ def make_image(file, size):
451
484
os .makedirs (os .path .dirname (file ), exist_ok = True )
452
485
PIL .Image .fromarray (np .zeros ((* size , 3 ), dtype = np .uint8 )).save (file )
453
486
454
- def make_tar (root , name , * files , remove_files = True ):
455
- name = f"{ os .path .splitext (name )[0 ]} .tar"
456
- archive = os .path .join (root , name )
457
-
458
- with tarfile .open (archive , "w" ) as fh :
459
- for file in files :
460
- fh .add (os .path .join (root , file ), arcname = file )
461
-
462
- if remove_files :
463
- for file in [os .path .join (root , file ) for file in files ]:
464
- if os .path .isdir (file ):
465
- dir_util .remove_tree (file )
466
- else :
467
- os .remove (file )
468
-
469
- return name , compute_md5 (archive )
470
-
471
487
def make_devkit_archive (stack , root , split ):
472
488
archive = DEVKITS [split ]
473
489
files = []
474
490
475
491
meta = make_categories_txt (root , CATEGORIES )
476
- mock_class_attribute (stack , "_CATEGORIES_META" , meta )
492
+ mock_class_attribute (stack , mock_target ( "_CATEGORIES_META" ) , meta )
477
493
files .append (meta [0 ])
478
494
479
495
meta = {split : make_file_list_txt (root , FILE_LISTS [split ])}
480
- mock_class_attribute (stack , "_FILE_LIST_META" , meta )
496
+ mock_class_attribute (stack , mock_target ( "_FILE_LIST_META" ) , meta )
481
497
files .extend ([item [0 ] for item in meta .values ()])
482
498
483
499
meta = {VARIANTS [split ]: make_tar (root , archive , * files )}
484
- mock_class_attribute (stack , "_DEVKIT_META" , meta )
500
+ mock_class_attribute (stack , mock_target ( "_DEVKIT_META" ) , meta )
485
501
486
502
def make_images_archive (stack , root , split , small ):
487
503
archive , folder_default , folder_renamed = IMAGES [(split , small )]
@@ -493,20 +509,97 @@ def make_images_archive(stack, root, split, small):
493
509
make_image (os .path .join (root , folder_default , image ), image_size )
494
510
495
511
meta = {(split , small ): make_tar (root , archive , folder_default )}
496
- mock_class_attribute (stack , "_IMAGES_META" , meta )
512
+ mock_class_attribute (stack , mock_target ( "_IMAGES_META" ) , meta )
497
513
498
514
return [(os .path .join (root , folder_renamed , image ), idx ) for image , idx in zip (images , idcs )]
499
515
500
516
with contextlib .ExitStack () as stack , get_tmp_dir () as root :
501
517
make_devkit_archive (stack , root , split )
502
518
class_to_idx = dict (CATEGORIES_CONTENT )
503
519
classes = list (class_to_idx .keys ())
520
+
504
521
data = {"class_to_idx" : class_to_idx , "classes" : classes }
522
+ data ["imgs" ] = make_images_archive (stack , root , split , small )
505
523
506
- if extract_images :
507
- data ["imgs" ] = make_images_archive (stack , root , split , small )
508
- else :
509
- stack .enter_context (unittest .mock .patch (mock_target ("download_images" )))
510
- data ["imgs" ] = None
524
+ clean_dir (root , ".tar$" )
525
+
526
+ yield root , data
527
+
528
+
529
@contextlib.contextmanager
def stl10_root(_extracted=False):
    """Provide a temporary root directory that contains a fake, gzipped STL10 archive.

    The binary files are zero-filled. While the context is active, the ``train_list``, ``test_list``
    and ``tgz_md5`` attributes of ``torchvision.datasets.stl10.STL10`` are patched with the names and
    MD5 digests of the generated files. The unpacked folder is removed again, so only the archive is
    left in the root.

    Args:
        _extracted: Not used by this function.

    Yields:
        Tuple ``(root, data)`` where ``data`` is a dict with the keys

        - ``"num_images_in_folds"``: tuple with the number of indices written per fold,
        - ``"num_images_in_split"``: dict with the image counts for ``train``, ``unlabeled`` and ``test``,
        - ``"archive"``: file name of the archive inside ``root``.
    """
    CLASS_NAMES = ("airplane", "bird")
    ARCHIVE_NAME = "stl10_binary"
    NUM_FOLDS = 10

    def mock_target(attr, partial="torchvision.datasets.stl10.STL10"):
        return f"{partial}.{attr}"

    def make_binary_file(num_elements, root, name):
        # Writes num_elements zero bytes and reports the file name together with its checksum
        file = os.path.join(root, name)
        np.zeros(num_elements, dtype=np.uint8).tofile(file)
        return name, compute_md5(file)

    def make_image_file(num_images, root, name, num_channels=3, height=96, width=96):
        return make_binary_file(num_images * num_channels * height * width, root, name)

    def make_label_file(num_images, root, name):
        # One byte per image
        return make_binary_file(num_images, root, name)

    def make_class_names_file(root, name="class_names.txt"):
        with open(os.path.join(root, name), "w") as fh:
            # Distinct loop variable so the ``name`` parameter is not shadowed
            for class_name in CLASS_NAMES:
                fh.write(f"{class_name}\n")

    def make_fold_indices_file(root):
        # Fold ``i`` (zero based) gets ``i + 1`` consecutive indices, so every fold has a distinct size
        offset = 0
        with open(os.path.join(root, "fold_indices.txt"), "w") as fh:
            for fold in range(NUM_FOLDS):
                line = " ".join([str(idx) for idx in range(offset, offset + fold + 1)])
                fh.write(f"{line}\n")
                offset += fold + 1

        return tuple(range(1, NUM_FOLDS + 1))

    def make_train_files(stack, root, num_unlabeled_images=1):
        num_images_in_fold = make_fold_indices_file(root)
        num_train_images = sum(num_images_in_fold)

        train_list = [
            list(make_image_file(num_train_images, root, "train_X.bin")),
            list(make_label_file(num_train_images, root, "train_y.bin")),
            # Use the parameter rather than a literal 1 so the file agrees with the reported count
            list(make_image_file(num_unlabeled_images, root, "unlabeled_X.bin")),
        ]
        mock_class_attribute(stack, target=mock_target("train_list"), new=train_list)

        return num_images_in_fold, dict(train=num_train_images, unlabeled=num_unlabeled_images)

    def make_test_files(stack, root, num_images=2):
        test_list = [
            list(make_image_file(num_images, root, "test_X.bin")),
            list(make_label_file(num_images, root, "test_y.bin")),
        ]
        mock_class_attribute(stack, target=mock_target("test_list"), new=test_list)

        return dict(test=num_images)

    def make_archive(stack, root, name):
        # The folder ``name`` inside ``root`` is packed into ``<name>.tar.gz``
        archive, md5 = make_tar(root, name, name, compression="gz")
        mock_class_attribute(stack, target=mock_target("tgz_md5"), new=md5)
        return archive

    with contextlib.ExitStack() as stack, get_tmp_dir() as root:
        archive_folder = os.path.join(root, ARCHIVE_NAME)
        os.mkdir(archive_folder)

        num_images_in_folds, num_images_in_split = make_train_files(stack, archive_folder)
        num_images_in_split.update(make_test_files(stack, archive_folder))

        make_class_names_file(archive_folder)

        archive = make_archive(stack, root, ARCHIVE_NAME)

        # Only the archive should remain in the root
        dir_util.remove_tree(archive_folder)
        data = dict(num_images_in_folds=num_images_in_folds, num_images_in_split=num_images_in_split, archive=archive)

        yield root, data
0 commit comments