1
1
import csv
2
- import os
3
- import os .path
2
+ import pathlib
4
3
from typing import Any , Callable , Optional , Tuple
5
4
6
5
import torch
@@ -38,7 +37,17 @@ def __init__(
38
37
self ._split = verify_str_arg (split , "split" , self ._RESOURCES .keys ())
39
38
super ().__init__ (root , transform = transform , target_transform = target_transform )
40
39
41
- with open (self ._verify_integrity (), "r" , newline = "" ) as file :
40
+ base_folder = pathlib .Path (self .root ) / "fer2013"
41
+ file_name , md5 = self ._RESOURCES [self ._split ]
42
+ data_file = base_folder / file_name
43
+ if not check_integrity (str (data_file ), md5 = md5 ):
44
+ raise RuntimeError (
45
+ f"{ file_name } not found in { base_folder } or corrupted. "
46
+ f"You can download it from "
47
+ f"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
48
+ )
49
+
50
+ with open (data_file , "r" , newline = "" ) as file :
42
51
self ._samples = [
43
52
(
44
53
torch .tensor ([int (idx ) for idx in row ["pixels" ].split ()], dtype = torch .uint8 ).reshape (48 , 48 ),
@@ -62,17 +71,5 @@ def __getitem__(self, idx: int) -> Tuple[Any, Any]:
62
71
63
72
return image , target
64
73
65
- def _verify_integrity (self ):
66
- base_folder = os .path .join (self .root , "fer2013" )
67
- file_name , md5 = self ._RESOURCES [self ._split ]
68
- file = os .path .join (base_folder , file_name )
69
- if not check_integrity (file , md5 = md5 ):
70
- raise RuntimeError (
71
- f"{ file_name } not found in { base_folder } or corrupted. "
72
- f"You can download it from "
73
- f"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
74
- )
75
- return file
76
-
77
74
def extra_repr (self ) -> str :
78
75
return f"split={ self ._split } "
0 commit comments