@@ -13,9 +13,21 @@ class FER2013(VisionDataset):
13
13
"""`FER2013
14
14
<https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge>`_ Dataset.
15
15
16
+ .. note::
17
+ This dataset can return test labels only if ``fer2013.csv`` OR
18
+ ``icml_face_data.csv`` is present in ``root/fer2013/``. If only
19
+ ``train.csv`` and ``test.csv`` are present, the test labels are set to
20
+ ``None``.
21
+
16
22
Args:
17
23
root (str or ``pathlib.Path``): Root directory of dataset where directory
18
- ``root/fer2013`` exists.
24
+ ``root/fer2013`` exists. This directory may contain either
25
+ ``fer2013.csv``, ``icml_face_data.csv``, or both ``train.csv`` and
26
+ ``test.csv``. Precedence is given in that order, i.e. if
27
+ ``fer2013.csv`` is present then the rest of the files will be
28
+ ignored. All these (combinations of) files contain the same data and
29
+ are supported for convenience, but only ``fer2013.csv`` and
30
+ ``icml_face_data.csv`` are able to return non-None test labels.
19
31
split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
20
32
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
21
33
version. E.g., ``transforms.RandomCrop``
@@ -25,6 +37,25 @@ class FER2013(VisionDataset):
25
37
_RESOURCES = {
26
38
"train" : ("train.csv" , "3f0dfb3d3fd99c811a1299cb947e3131" ),
27
39
"test" : ("test.csv" , "b02c2298636a634e8c2faabbf3ea9a23" ),
40
+ # The fer2013.csv and icml_face_data.csv files contain both train and
41
+ # test instances, and unlike test.csv they contain the labels for the
42
+ # test instances. We give these 2 files precedence over train.csv and
43
+ # test.csv. And yes, they both contain the same data, but with different
44
+ # column names (note the spaces) and ordering:
45
+ # $ head -n 1 fer2013.csv icml_face_data.csv train.csv test.csv
46
+ # ==> fer2013.csv <==
47
+ # emotion,pixels,Usage
48
+ #
49
+ # ==> icml_face_data.csv <==
50
+ # emotion, Usage, pixels
51
+ #
52
+ # ==> train.csv <==
53
+ # emotion,pixels
54
+ #
55
+ # ==> test.csv <==
56
+ # pixels
57
+ "fer" : ("fer2013.csv" , "f8428a1edbd21e88f42c73edd2a14f95" ),
58
+ "icml" : ("icml_face_data.csv" , "b114b9e04e6949e5fe8b6a98b3892b1d" ),
28
59
}
29
60
30
61
def __init__ (
@@ -34,11 +65,13 @@ def __init__(
34
65
transform : Optional [Callable ] = None ,
35
66
target_transform : Optional [Callable ] = None ,
36
67
) -> None :
37
- self ._split = verify_str_arg (split , "split" , self . _RESOURCES . keys ( ))
68
+ self ._split = verify_str_arg (split , "split" , ( "train" , "test" ))
38
69
super ().__init__ (root , transform = transform , target_transform = target_transform )
39
70
40
71
base_folder = pathlib .Path (self .root ) / "fer2013"
41
- file_name , md5 = self ._RESOURCES [self ._split ]
72
+ use_fer_file = (base_folder / self ._RESOURCES ["fer" ][0 ]).exists ()
73
+ use_icml_file = not use_fer_file and (base_folder / self ._RESOURCES ["icml" ][0 ]).exists ()
74
+ file_name , md5 = self ._RESOURCES ["fer" if use_fer_file else "icml" if use_icml_file else self ._split ]
42
75
data_file = base_folder / file_name
43
76
if not check_integrity (str (data_file ), md5 = md5 ):
44
77
raise RuntimeError (
@@ -47,14 +80,26 @@ def __init__(
47
80
f"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
48
81
)
49
82
83
+ pixels_key = " pixels" if use_icml_file else "pixels"
84
+ usage_key = " Usage" if use_icml_file else "Usage"
85
+
86
+ def get_img (row ):
87
+ return torch .tensor ([int (idx ) for idx in row [pixels_key ].split ()], dtype = torch .uint8 ).reshape (48 , 48 )
88
+
89
+ def get_label (row ):
90
+ if use_fer_file or use_icml_file or self ._split == "train" :
91
+ return int (row ["emotion" ])
92
+ else :
93
+ return None
94
+
50
95
with open (data_file , "r" , newline = "" ) as file :
51
- self . _samples = [
52
- (
53
- torch . tensor ([ int ( idx ) for idx in row [ "pixels" ]. split ()], dtype = torch . uint8 ). reshape ( 48 , 48 ),
54
- int ( row [ "emotion" ] ) if "emotion" in row else None ,
55
- )
56
- for row in csv . DictReader ( file )
57
- ]
96
+ rows = ( row for row in csv . DictReader ( file ))
97
+
98
+ if use_fer_file or use_icml_file :
99
+ valid_keys = ( "Training" , ) if self . _split == "train" else ( "PublicTest" , "PrivateTest" )
100
+ rows = ( row for row in rows if row [ usage_key ] in valid_keys )
101
+
102
+ self . _samples = [( get_img ( row ), get_label ( row )) for row in rows ]
58
103
59
104
def __len__ (self ) -> int :
60
105
return len (self ._samples )
0 commit comments