Skip to content

Commit 0325fdd

Browse files
authored
Return labels for FER2013 if possible (#8452)
1 parent ab0b9a4 commit 0325fdd

File tree

2 files changed

+111
-26
lines changed

2 files changed

+111
-26
lines changed

test/test_datasets.py

Lines changed: 56 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2442,28 +2442,68 @@ def inject_fake_data(self, tmpdir, config):
24422442
base_folder = os.path.join(tmpdir, "fer2013")
24432443
os.makedirs(base_folder)
24442444

2445+
use_icml = config.pop("use_icml", False)
2446+
use_fer = config.pop("use_fer", False)
2447+
24452448
num_samples = 5
2446-
with open(os.path.join(base_folder, f"{config['split']}.csv"), "w", newline="") as file:
2447-
writer = csv.DictWriter(
2448-
file,
2449-
fieldnames=("emotion", "pixels") if config["split"] == "train" else ("pixels",),
2450-
quoting=csv.QUOTE_NONNUMERIC,
2451-
quotechar='"',
2452-
)
2453-
writer.writeheader()
2454-
for _ in range(num_samples):
2455-
row = dict(
2456-
pixels=" ".join(
2457-
str(pixel) for pixel in datasets_utils.create_image_or_video_tensor((48, 48)).view(-1).tolist()
2458-
)
2449+
2450+
if use_icml or use_fer:
2451+
pixels_key, usage_key = (" pixels", " Usage") if use_icml else ("pixels", "Usage")
2452+
fieldnames = ("emotion", usage_key, pixels_key) if use_icml else ("emotion", pixels_key, usage_key)
2453+
filename = "icml_face_data.csv" if use_icml else "fer2013.csv"
2454+
with open(os.path.join(base_folder, filename), "w", newline="") as file:
2455+
writer = csv.DictWriter(
2456+
file,
2457+
fieldnames=fieldnames,
2458+
quoting=csv.QUOTE_NONNUMERIC,
2459+
quotechar='"',
24592460
)
2460-
if config["split"] == "train":
2461-
row["emotion"] = str(int(torch.randint(0, 7, ())))
2461+
writer.writeheader()
2462+
for i in range(num_samples):
2463+
row = {
2464+
"emotion": str(int(torch.randint(0, 7, ()))),
2465+
usage_key: "Training" if i % 2 else "PublicTest",
2466+
pixels_key: " ".join(
2467+
str(pixel)
2468+
for pixel in datasets_utils.create_image_or_video_tensor((48, 48)).view(-1).tolist()
2469+
),
2470+
}
2471+
2472+
writer.writerow(row)
2473+
else:
2474+
with open(os.path.join(base_folder, f"{config['split']}.csv"), "w", newline="") as file:
2475+
writer = csv.DictWriter(
2476+
file,
2477+
fieldnames=("emotion", "pixels") if config["split"] == "train" else ("pixels",),
2478+
quoting=csv.QUOTE_NONNUMERIC,
2479+
quotechar='"',
2480+
)
2481+
writer.writeheader()
2482+
for _ in range(num_samples):
2483+
row = dict(
2484+
pixels=" ".join(
2485+
str(pixel)
2486+
for pixel in datasets_utils.create_image_or_video_tensor((48, 48)).view(-1).tolist()
2487+
)
2488+
)
2489+
if config["split"] == "train":
2490+
row["emotion"] = str(int(torch.randint(0, 7, ())))
24622491

2463-
writer.writerow(row)
2492+
writer.writerow(row)
24642493

24652494
return num_samples
24662495

2496+
def test_icml_file(self):
2497+
config = {"split": "test"}
2498+
with self.create_dataset(config=config) as (dataset, _):
2499+
assert all(s[1] is None for s in dataset)
2500+
2501+
for split in ("train", "test"):
2502+
for d in ({"use_icml": True}, {"use_fer": True}):
2503+
config = {"split": split, **d}
2504+
with self.create_dataset(config=config) as (dataset, _):
2505+
assert all(s[1] is not None for s in dataset)
2506+
24672507

24682508
class GTSRBTestCase(datasets_utils.ImageDatasetTestCase):
24692509
DATASET_CLASS = datasets.GTSRB

torchvision/datasets/fer2013.py

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,21 @@ class FER2013(VisionDataset):
1313
"""`FER2013
1414
<https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge>`_ Dataset.
1515
16+
.. note::
17+
This dataset can return test labels only if ``fer2013.csv`` OR
18+
``icml_face_data.csv`` is present in ``root/fer2013/``. If only
19+
``train.csv`` and ``test.csv`` are present, the test labels are set to
20+
``None``.
21+
1622
Args:
1723
root (str or ``pathlib.Path``): Root directory of dataset where directory
18-
``root/fer2013`` exists.
24+
``root/fer2013`` exists. This directory may contain either
25+
``fer2013.csv``, ``icml_face_data.csv``, or both ``train.csv`` and
26+
``test.csv``. Precedence is given in that order, i.e. if
27+
``fer2013.csv`` is present then the rest of the files will be
28+
ignored. All these (combinations of) files contain the same data and
29+
are supported for convenience, but only ``fer2013.csv`` and
30+
``icml_face_data.csv`` are able to return non-None test labels.
1931
split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
2032
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
2133
version. E.g., ``transforms.RandomCrop``
@@ -25,6 +37,25 @@ class FER2013(VisionDataset):
2537
_RESOURCES = {
2638
"train": ("train.csv", "3f0dfb3d3fd99c811a1299cb947e3131"),
2739
"test": ("test.csv", "b02c2298636a634e8c2faabbf3ea9a23"),
40+
# The fer2013.csv and icml_face_data.csv files contain both train and
41+
# test instances, and unlike test.csv they contain the labels for the
42+
# test instances. We give these 2 files precedence over train.csv and
43+
# test.csv. And yes, they both contain the same data, but with different
44+
# column names (note the spaces) and ordering:
45+
# $ head -n 1 fer2013.csv icml_face_data.csv train.csv test.csv
46+
# ==> fer2013.csv <==
47+
# emotion,pixels,Usage
48+
#
49+
# ==> icml_face_data.csv <==
50+
# emotion, Usage, pixels
51+
#
52+
# ==> train.csv <==
53+
# emotion,pixels
54+
#
55+
# ==> test.csv <==
56+
# pixels
57+
"fer": ("fer2013.csv", "f8428a1edbd21e88f42c73edd2a14f95"),
58+
"icml": ("icml_face_data.csv", "b114b9e04e6949e5fe8b6a98b3892b1d"),
2859
}
2960

3061
def __init__(
@@ -34,11 +65,13 @@ def __init__(
3465
transform: Optional[Callable] = None,
3566
target_transform: Optional[Callable] = None,
3667
) -> None:
37-
self._split = verify_str_arg(split, "split", self._RESOURCES.keys())
68+
self._split = verify_str_arg(split, "split", ("train", "test"))
3869
super().__init__(root, transform=transform, target_transform=target_transform)
3970

4071
base_folder = pathlib.Path(self.root) / "fer2013"
41-
file_name, md5 = self._RESOURCES[self._split]
72+
use_fer_file = (base_folder / self._RESOURCES["fer"][0]).exists()
73+
use_icml_file = not use_fer_file and (base_folder / self._RESOURCES["icml"][0]).exists()
74+
file_name, md5 = self._RESOURCES["fer" if use_fer_file else "icml" if use_icml_file else self._split]
4275
data_file = base_folder / file_name
4376
if not check_integrity(str(data_file), md5=md5):
4477
raise RuntimeError(
@@ -47,14 +80,26 @@ def __init__(
4780
f"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
4881
)
4982

83+
pixels_key = " pixels" if use_icml_file else "pixels"
84+
usage_key = " Usage" if use_icml_file else "Usage"
85+
86+
def get_img(row):
87+
return torch.tensor([int(idx) for idx in row[pixels_key].split()], dtype=torch.uint8).reshape(48, 48)
88+
89+
def get_label(row):
90+
if use_fer_file or use_icml_file or self._split == "train":
91+
return int(row["emotion"])
92+
else:
93+
return None
94+
5095
with open(data_file, "r", newline="") as file:
51-
self._samples = [
52-
(
53-
torch.tensor([int(idx) for idx in row["pixels"].split()], dtype=torch.uint8).reshape(48, 48),
54-
int(row["emotion"]) if "emotion" in row else None,
55-
)
56-
for row in csv.DictReader(file)
57-
]
96+
rows = (row for row in csv.DictReader(file))
97+
98+
if use_fer_file or use_icml_file:
99+
valid_keys = ("Training",) if self._split == "train" else ("PublicTest", "PrivateTest")
100+
rows = (row for row in rows if row[usage_key] in valid_keys)
101+
102+
self._samples = [(get_img(row), get_label(row)) for row in rows]
58103

59104
def __len__(self) -> int:
60105
return len(self._samples)

0 commit comments

Comments
 (0)