import torch
import codecs
import string
-import gzip
-import lzma
-from typing import Any, Callable, Dict, IO, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple
from urllib.error import URLError
-from .utils import download_url, download_and_extract_archive, extract_archive, \
-    verify_str_arg
+from .utils import download_and_extract_archive, extract_archive, verify_str_arg, check_integrity
+import shutil


class MNIST(VisionDataset):
@@ -81,18 +79,42 @@ def __init__(
                                    target_transform=target_transform)
        self.train = train  # training set or test set

+        if self._check_legacy_exist():
+            self.data, self.targets = self._load_legacy_data()
+            return
+
        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

-        if self.train:
-            data_file = self.training_file
-        else:
-            data_file = self.test_file
-        self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))
+        self.data, self.targets = self._load_data()
+
+    def _check_legacy_exist(self):
+        processed_folder_exists = os.path.exists(self.processed_folder)
+        if not processed_folder_exists:
+            return False
+
+        return all(
+            check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file)
+        )
+
+    def _load_legacy_data(self):
+        # This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data
+        # directly.
+        data_file = self.training_file if self.train else self.test_file
+        return torch.load(os.path.join(self.processed_folder, data_file))
+
+    def _load_data(self):
+        image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
+        data = read_image_file(os.path.join(self.raw_folder, image_file))
+
+        label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
+        targets = read_label_file(os.path.join(self.raw_folder, label_file))
+
+        return data, targets

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
@@ -132,19 +154,18 @@ def class_to_idx(self) -> Dict[str, int]:
        return {_class: i for i, _class in enumerate(self.classes)}

    def _check_exists(self) -> bool:
-        return (os.path.exists(os.path.join(self.processed_folder,
-                                            self.training_file)) and
-                os.path.exists(os.path.join(self.processed_folder,
-                                            self.test_file)))
+        return all(
+            check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]))
+            for url, _ in self.resources
+        )

    def download(self) -> None:
-        """Download the MNIST data if it doesn't exist in processed_folder already."""
+        """Download the MNIST data if it doesn't exist already."""

        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)
-        os.makedirs(self.processed_folder, exist_ok=True)

        # download files
        for filename, md5 in self.resources:
@@ -168,24 +189,6 @@ def download(self) -> None:
            else:
                raise RuntimeError("Error downloading {}".format(filename))

-        # process and save as torch files
-        print('Processing...')
-
-        training_set = (
-            read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),
-            read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte'))
-        )
-        test_set = (
-            read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')),
-            read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte'))
-        )
-        with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
-            torch.save(training_set, f)
-        with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f:
-            torch.save(test_set, f)
-
-        print('Done!')
-
    def extra_repr(self) -> str:
        return "Split: {}".format("Train" if self.train is True else "Test")

@@ -298,44 +301,39 @@ def _training_file(split) -> str:
    def _test_file(split) -> str:
        return 'test_{}.pt'.format(split)

+    @property
+    def _file_prefix(self) -> str:
+        return f"emnist-{self.split}-{'train' if self.train else 'test'}"
+
+    @property
+    def images_file(self) -> str:
+        return os.path.join(self.raw_folder, f"{self._file_prefix}-images-idx3-ubyte")
+
+    @property
+    def labels_file(self) -> str:
+        return os.path.join(self.raw_folder, f"{self._file_prefix}-labels-idx1-ubyte")
+
+    def _load_data(self):
+        return read_image_file(self.images_file), read_label_file(self.labels_file)
+
+    def _check_exists(self) -> bool:
+        return all(check_integrity(file) for file in (self.images_file, self.labels_file))
+
    def download(self) -> None:
-        """Download the EMNIST data if it doesn't exist in processed_folder already."""
-        import shutil
+        """Download the EMNIST data if it doesn't exist already."""

        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)
-        os.makedirs(self.processed_folder, exist_ok=True)

-        # download files
-        print('Downloading and extracting zip archive')
-        download_and_extract_archive(self.url, download_root=self.raw_folder, filename="emnist.zip",
-                                     remove_finished=True, md5=self.md5)
+        download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5)
        gzip_folder = os.path.join(self.raw_folder, 'gzip')
        for gzip_file in os.listdir(gzip_folder):
            if gzip_file.endswith('.gz'):
-                extract_archive(os.path.join(gzip_folder, gzip_file), gzip_folder)
-
-        # process and save as torch files
-        for split in self.splits:
-            print('Processing ' + split)
-            training_set = (
-                read_image_file(os.path.join(gzip_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))),
-                read_label_file(os.path.join(gzip_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split)))
-            )
-            test_set = (
-                read_image_file(os.path.join(gzip_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))),
-                read_label_file(os.path.join(gzip_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split)))
-            )
-            with open(os.path.join(self.processed_folder, self._training_file(split)), 'wb') as f:
-                torch.save(training_set, f)
-            with open(os.path.join(self.processed_folder, self._test_file(split)), 'wb') as f:
-                torch.save(test_set, f)
+                extract_archive(os.path.join(gzip_folder, gzip_file), self.raw_folder)
        shutil.rmtree(gzip_folder)

-        print('Done!')
-

class QMNIST(MNIST):
    """`QMNIST <https://github.com/facebookresearch/qmnist>`_ Dataset.
@@ -404,40 +402,51 @@ def __init__(
        self.test_file = self.data_file
        super(QMNIST, self).__init__(root, train, **kwargs)

+    @property
+    def images_file(self) -> str:
+        (url, _), _ = self.resources[self.subsets[self.what]]
+        return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])
+
+    @property
+    def labels_file(self) -> str:
+        _, (url, _) = self.resources[self.subsets[self.what]]
+        return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])
+
+    def _check_exists(self) -> bool:
+        return all(check_integrity(file) for file in (self.images_file, self.labels_file))
+
+    def _load_data(self):
+        data = read_sn3_pascalvincent_tensor(self.images_file)
+        assert (data.dtype == torch.uint8)
+        assert (data.ndimension() == 3)
+
+        targets = read_sn3_pascalvincent_tensor(self.labels_file).long()
+        assert (targets.ndimension() == 2)
+
+        if self.what == 'test10k':
+            data = data[0:10000, :, :].clone()
+            targets = targets[0:10000, :].clone()
+        elif self.what == 'test50k':
+            data = data[10000:, :, :].clone()
+            targets = targets[10000:, :].clone()
+
+        return data, targets
+
    def download(self) -> None:
-        """Download the QMNIST data if it doesn't exist in processed_folder already.
+        """Download the QMNIST data if it doesn't exist already.

        Note that we only download what has been asked for (argument 'what').
        """
        if self._check_exists():
            return
+
        os.makedirs(self.raw_folder, exist_ok=True)
-        os.makedirs(self.processed_folder, exist_ok=True)
        split = self.resources[self.subsets[self.what]]
-        files = []

-        # download data files if not already there
        for url, md5 in split:
            filename = url.rpartition('/')[2]
            file_path = os.path.join(self.raw_folder, filename)
            if not os.path.isfile(file_path):
-                download_url(url, root=self.raw_folder, filename=filename, md5=md5)
-            files.append(file_path)
-
-        # process and save as torch files
-        print('Processing...')
-        data = read_sn3_pascalvincent_tensor(files[0])
-        assert (data.dtype == torch.uint8)
-        assert (data.ndimension() == 3)
-        targets = read_sn3_pascalvincent_tensor(files[1]).long()
-        assert (targets.ndimension() == 2)
-        if self.what == 'test10k':
-            data = data[0:10000, :, :].clone()
-            targets = targets[0:10000, :].clone()
-        if self.what == 'test50k':
-            data = data[10000:, :, :].clone()
-            targets = targets[10000:, :].clone()
-        with open(os.path.join(self.processed_folder, self.data_file), 'wb') as f:
-            torch.save((data, targets), f)
+                download_and_extract_archive(url, self.raw_folder, filename=filename, md5=md5)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        # redefined to handle the compat flag
@@ -459,19 +468,6 @@ def get_int(b: bytes) -> int:
    return int(codecs.encode(b, 'hex'), 16)


-def open_maybe_compressed_file(path: Union[str, IO]) -> Union[IO, gzip.GzipFile]:
-    """Return a file object that possibly decompresses 'path' on the fly.
-       Decompression occurs when argument `path` is a string and ends with '.gz' or '.xz'.
-    """
-    if not isinstance(path, torch._six.string_classes):
-        return path
-    if path.endswith('.gz'):
-        return gzip.open(path, 'rb')
-    if path.endswith('.xz'):
-        return lzma.open(path, 'rb')
-    return open(path, 'rb')
-
-
SN3_PASCALVINCENT_TYPEMAP = {
    8: (torch.uint8, np.uint8, np.uint8),
    9: (torch.int8, np.int8, np.int8),
@@ -482,12 +478,12 @@ def open_maybe_compressed_file(path: Union[str, IO]) -> Union[IO, gzip.GzipFile]
}


-def read_sn3_pascalvincent_tensor(path: Union[str, IO], strict: bool = True) -> torch.Tensor:
+def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor:
    """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
       Argument may be a filename, compressed filename, or file object.
    """
    # read
-    with open_maybe_compressed_file(path) as f:
+    with open(path, "rb") as f:
        data = f.read()
    # parse
    magic = get_int(data[0:4])
@@ -503,16 +499,14 @@ def read_sn3_pascalvincent_tensor(path: Union[str, IO], strict: bool = True) ->


def read_label_file(path: str) -> torch.Tensor:
-    with open(path, 'rb') as f:
-        x = read_sn3_pascalvincent_tensor(f, strict=False)
+    x = read_sn3_pascalvincent_tensor(path, strict=False)
    assert (x.dtype == torch.uint8)
    assert (x.ndimension() == 1)
    return x.long()


def read_image_file(path: str) -> torch.Tensor:
-    with open(path, 'rb') as f:
-        x = read_sn3_pascalvincent_tensor(f, strict=False)
+    x = read_sn3_pascalvincent_tensor(path, strict=False)
    assert (x.dtype == torch.uint8)
    assert (x.ndimension() == 3)
    return x
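
A minimal usage sketch (not part of the diff): after this change, `MNIST._load_data` reads the raw IDX files under the dataset's raw folder directly instead of loading a cached `processed/*.pt` file, while `_load_legacy_data` keeps old caches working. The `./data` root path and the `ToTensor` transform below are illustrative assumptions, not anything the patch prescribes.

from torchvision import datasets, transforms

# Download (if needed) and load the training split via the public API;
# with this patch the samples come straight from the raw IDX files.
train_set = datasets.MNIST(root="./data", train=True, download=True,
                           transform=transforms.ToTensor())
img, label = train_set[0]  # img: 1x28x28 float tensor, label: int class index
print(img.shape, label)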