@@ -37,27 +37,28 @@ def __init__(
37
37
38
38
def __iter__ (self ) -> Iterator [torch .Tensor ]:
39
39
for _ , file in self .datapipe :
40
- read = functools .partial (fromfile , file , byte_order = "big" )
40
+ try :
41
+ read = functools .partial (fromfile , file , byte_order = "big" )
41
42
42
- magic = int (read (dtype = torch .int32 , count = 1 ))
43
- dtype = self ._DTYPE_MAP [magic // 256 ]
44
- ndim = magic % 256 - 1
43
+ magic = int (read (dtype = torch .int32 , count = 1 ))
44
+ dtype = self ._DTYPE_MAP [magic // 256 ]
45
+ ndim = magic % 256 - 1
45
46
46
- num_samples = int (read (dtype = torch .int32 , count = 1 ))
47
- shape = cast (List [int ], read (dtype = torch .int32 , count = ndim ).tolist ()) if ndim else []
48
- count = prod (shape ) if shape else 1
47
+ num_samples = int (read (dtype = torch .int32 , count = 1 ))
48
+ shape = cast (List [int ], read (dtype = torch .int32 , count = ndim ).tolist ()) if ndim else []
49
+ count = prod (shape ) if shape else 1
49
50
50
- start = self .start or 0
51
- stop = min (self .stop , num_samples ) if self .stop else num_samples
51
+ start = self .start or 0
52
+ stop = min (self .stop , num_samples ) if self .stop else num_samples
52
53
53
- if start :
54
- num_bytes_per_value = (torch .finfo if dtype .is_floating_point else torch .iinfo )(dtype ).bits // 8
55
- file .seek (num_bytes_per_value * count * start , 1 )
54
+ if start :
55
+ num_bytes_per_value = (torch .finfo if dtype .is_floating_point else torch .iinfo )(dtype ).bits // 8
56
+ file .seek (num_bytes_per_value * count * start , 1 )
56
57
57
- for _ in range (stop - start ):
58
- yield read (dtype = dtype , count = count ).reshape (shape )
59
-
60
- file .close ()
58
+ for _ in range (stop - start ):
59
+ yield read (dtype = dtype , count = count ).reshape (shape )
60
+ finally :
61
+ file .close ()
61
62
62
63
63
64
class _MNISTBase (Dataset ):
0 commit comments