Skip to content

Commit 8ac9e34

Browse files
author
Nathan Salberg
committed
Updated test cases for the fer2013 dataset
1 parent fc60b59 commit 8ac9e34

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

test/test_datasets.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2443,22 +2443,26 @@ def inject_fake_data(self, tmpdir, config):
24432443
os.makedirs(base_folder)
24442444

24452445
num_samples = 5
2446-
with open(os.path.join(base_folder, f"{config['split']}.csv"), "w", newline="") as file:
2446+
with open(os.path.join(base_folder, "icml_face_data.csv"), "w", newline="") as file:
24472447
writer = csv.DictWriter(
24482448
file,
2449-
fieldnames=("emotion", "pixels") if config["split"] == "train" else ("pixels",),
2449+
fieldnames=("emotion", "pixels","Usage"),
24502450
quoting=csv.QUOTE_NONNUMERIC,
24512451
quotechar='"',
24522452
)
24532453
writer.writeheader()
2454-
for _ in range(num_samples):
2454+
for i in range(num_samples):
24552455
row = dict(
24562456
pixels=" ".join(
24572457
str(pixel) for pixel in datasets_utils.create_image_or_video_tensor((48, 48)).view(-1).tolist()
24582458
)
24592459
)
2460-
if config["split"] == "train":
2461-
row["emotion"] = str(int(torch.randint(0, 7, ())))
2460+
row["emotion"] = str(int(torch.randint(0, 7, ())))
2461+
2462+
if config["split"] == "test":
2463+
row["Usage"] = "PublicTest" if i % 2 == 0 else "PrivateTest"
2464+
else:
2465+
row["Usage"] = "Training"
24622466

24632467
writer.writerow(row)
24642468

torchvision/datasets/fer2013.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ def __init__(
5151
reader = csv.DictReader(file)
5252
self._samples = []
5353
for row in reader:
54-
cleaned_row = {name.strip(): value for name, value in row.items()}
55-
if self._split in cleaned_row["Usage"].lower():
54+
cleaned_row = {name.strip().lower(): value for name, value in row.items()}
55+
if self._split in cleaned_row["usage"].lower():
5656
self._samples.append(
5757
(
5858
torch.tensor(

0 commit comments

Comments
 (0)